<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/structured_data/collaborative_filtering_movielens/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Collaborative Filtering for Movie Recommendations"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Collaborative Filtering for Movie Recommendations"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Collaborative Filtering for Movie Recommendations</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); 
ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink active" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink2" href="/examples/structured_data/structured_data_classification_with_feature_space/">Structured data classification with FeatureSpace</a> <a class="nav-sublink2" href="/examples/structured_data/feature_space_advanced/">FeatureSpace advanced use cases</a> <a class="nav-sublink2" href="/examples/structured_data/imbalanced_classification/">Imbalanced classification: credit card fraud detection</a> <a class="nav-sublink2" href="/examples/structured_data/structured_data_classification_from_scratch/">Structured data classification from scratch</a> <a 
class="nav-sublink2" href="/examples/structured_data/wide_deep_cross_networks/">Structured data learning with Wide, Deep, and Cross networks</a> <a class="nav-sublink2" href="/examples/structured_data/classification_with_grn_and_vsn/">Classification with Gated Residual and Variable Selection Networks</a> <a class="nav-sublink2" href="/examples/structured_data/classification_with_tfdf/">Classification with TensorFlow Decision Forests</a> <a class="nav-sublink2" href="/examples/structured_data/deep_neural_decision_forests/">Classification with Neural Decision Forests</a> <a class="nav-sublink2" href="/examples/structured_data/tabtransformer/">Structured data learning with TabTransformer</a> <a class="nav-sublink2 active" href="/examples/structured_data/collaborative_filtering_movielens/">Collaborative Filtering for Movie Recommendations</a> <a class="nav-sublink2" href="/examples/structured_data/movielens_recommendations_transformers/">A Transformer-based recommendation system</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display 
== "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." 
autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return false } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/structured_data/'>Structured Data</a> / Collaborative Filtering for Movie Recommendations </div> <div class='k-content'> <h1 id="collaborative-filtering-for-movie-recommendations">Collaborative Filtering for Movie Recommendations</h1> <p><strong>Author:</strong> <a href="https://twitter.com/sidd2006">Siddhartha Banerjee</a><br> <strong>Date created:</strong> 2020/05/24<br> <strong>Last modified:</strong> 2020/05/24<br> <strong>Description:</strong> Recommending movies using a model trained on the MovieLens dataset.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img
class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/structured_data/ipynb/collaborative_filtering_movielens.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/structured_data/collaborative_filtering_movielens.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>This example demonstrates <a href="https://en.wikipedia.org/wiki/Collaborative_filtering">collaborative filtering</a> using the <a href="https://www.kaggle.com/c/movielens-100k">MovieLens dataset</a> to recommend movies to users. The MovieLens ratings dataset lists the ratings given by a set of users to a set of movies. Our goal is to predict ratings for movies a user has not yet watched; the movies with the highest predicted ratings can then be recommended to that user.</p> <p>The steps in the model are as follows:</p> <ol> <li>Map each user ID to a "user vector" via an embedding matrix.</li> <li>Map each movie ID to a "movie vector" via an embedding matrix.</li> <li>Compute the dot product between the user vector and the movie vector to obtain a match score between the user and the movie (the predicted rating).</li> <li>Train the embeddings via gradient descent using all known user-movie pairs.</li> </ol> <p><strong>References:</strong></p> <ul> <li><a href="https://dl.acm.org/doi/pdf/10.1145/371920.372071">Collaborative Filtering</a></li> <li><a href="https://dl.acm.org/doi/pdf/10.1145/3038912.3052569">Neural Collaborative Filtering</a></li> </ul> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span> <span class="kn">from</span> <span class="nn">pathlib</span> <span class="kn">import</span>
<span class="n">Path</span> <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">from</span> <span class="nn">zipfile</span> <span class="kn">import</span> <span class="n">ZipFile</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">ops</span> </code></pre></div> <hr /> <h2 id="first-load-the-data-and-apply-preprocessing">First, load the data and apply preprocessing</h2> <div class="codehilite"><pre><span></span><code><span class="c1"># Download the MovieLens ml-latest-small archive from http://files.grouplens.org/datasets/movielens/ml-latest-small.zip</span> <span class="c1"># Use the ratings.csv file</span> <span class="n">movielens_data_file_url</span> <span class="o">=</span> <span class="p">(</span> <span class="s2">"http://files.grouplens.org/datasets/movielens/ml-latest-small.zip"</span> <span class="p">)</span> <span class="n">movielens_zipped_file</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">get_file</span><span class="p">(</span> <span class="s2">"ml-latest-small.zip"</span><span class="p">,</span> <span class="n">movielens_data_file_url</span><span class="p">,</span> <span class="n">extract</span><span class="o">=</span><span class="kc">False</span> <span class="p">)</span> <span class="n">keras_datasets_path</span> <span class="o">=</span> <span class="n">Path</span><span class="p">(</span><span class="n">movielens_zipped_file</span><span class="p">)</span><span class="o">.</span><span class="n">parents</span><span
class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="n">movielens_dir</span> <span class="o">=</span> <span class="n">keras_datasets_path</span> <span class="o">/</span> <span class="s2">"ml-latest-small"</span> <span class="c1"># Only extract the data the first time the script is run.</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">movielens_dir</span><span class="o">.</span><span class="n">exists</span><span class="p">():</span> <span class="k">with</span> <span class="n">ZipFile</span><span class="p">(</span><span class="n">movielens_zipped_file</span><span class="p">,</span> <span class="s2">"r"</span><span class="p">)</span> <span class="k">as</span> <span class="nb">zip</span><span class="p">:</span> <span class="c1"># Extract files</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Extracting all the files now..."</span><span class="p">)</span> <span class="nb">zip</span><span class="o">.</span><span class="n">extractall</span><span class="p">(</span><span class="n">path</span><span class="o">=</span><span class="n">keras_datasets_path</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Done!"</span><span class="p">)</span> <span class="n">ratings_file</span> <span class="o">=</span> <span class="n">movielens_dir</span> <span class="o">/</span> <span class="s2">"ratings.csv"</span> <span class="n">df</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_csv</span><span class="p">(</span><span class="n">ratings_file</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Downloading data from http://files.grouplens.org/datasets/movielens/ml-latest-small.zip 978202/978202 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step Extracting all the files now... Done! 
</code></pre></div> </div> <p>Next, we encode users and movies as contiguous integer indices, since the <code>Embedding</code> layers used below expect inputs in the ranges <code>[0, num_users)</code> and <code>[0, num_movies)</code>.</p> <div class="codehilite"><pre><span></span><code><span class="n">user_ids</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="s2">"userId"</span><span class="p">]</span><span class="o">.</span><span class="n">unique</span><span class="p">()</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span> <span class="n">user2user_encoded</span> <span class="o">=</span> <span class="p">{</span><span class="n">x</span><span class="p">:</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">user_ids</span><span class="p">)}</span> <span class="n">userencoded2user</span> <span class="o">=</span> <span class="p">{</span><span class="n">i</span><span class="p">:</span> <span class="n">x</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">user_ids</span><span class="p">)}</span> <span class="n">movie_ids</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="s2">"movieId"</span><span class="p">]</span><span class="o">.</span><span class="n">unique</span><span class="p">()</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span> <span class="n">movie2movie_encoded</span> <span class="o">=</span> <span class="p">{</span><span class="n">x</span><span class="p">:</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span
class="n">movie_ids</span><span class="p">)}</span> <span class="n">movie_encoded2movie</span> <span class="o">=</span> <span class="p">{</span><span class="n">i</span><span class="p">:</span> <span class="n">x</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">movie_ids</span><span class="p">)}</span> <span class="n">df</span><span class="p">[</span><span class="s2">"user"</span><span class="p">]</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="s2">"userId"</span><span class="p">]</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">user2user_encoded</span><span class="p">)</span> <span class="n">df</span><span class="p">[</span><span class="s2">"movie"</span><span class="p">]</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="s2">"movieId"</span><span class="p">]</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">movie2movie_encoded</span><span class="p">)</span> <span class="n">num_users</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">user2user_encoded</span><span class="p">)</span> <span class="n">num_movies</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">movie_encoded2movie</span><span class="p">)</span> <span class="n">df</span><span class="p">[</span><span class="s2">"rating"</span><span class="p">]</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="s2">"rating"</span><span class="p">]</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span 
class="n">float32</span><span class="p">)</span> <span class="c1"># min and max ratings will be used to normalize the ratings later</span> <span class="n">min_rating</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">df</span><span class="p">[</span><span class="s2">"rating"</span><span class="p">])</span> <span class="n">max_rating</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">df</span><span class="p">[</span><span class="s2">"rating"</span><span class="p">])</span> <span class="nb">print</span><span class="p">(</span> <span class="s2">"Number of users: </span><span class="si">{}</span><span class="s2">, Number of Movies: </span><span class="si">{}</span><span class="s2">, Min rating: </span><span class="si">{}</span><span class="s2">, Max rating: </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span> <span class="n">num_users</span><span class="p">,</span> <span class="n">num_movies</span><span class="p">,</span> <span class="n">min_rating</span><span class="p">,</span> <span class="n">max_rating</span> <span class="p">)</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Number of users: 610, Number of Movies: 9724, Min rating: 0.5, Max rating: 5.0 </code></pre></div> </div> <hr /> <h2 id="prepare-training-and-validation-data">Prepare training and validation data</h2> <div class="codehilite"><pre><span></span><code><span class="n">df</span> <span class="o">=</span> <span class="n">df</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">frac</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">)</span> <span class="n">x</span> <span 
class="o">=</span> <span class="n">df</span><span class="p">[[</span><span class="s2">"user"</span><span class="p">,</span> <span class="s2">"movie"</span><span class="p">]]</span><span class="o">.</span><span class="n">values</span> <span class="c1"># Min-max scale the ratings into [0, 1] to match the range of the model's sigmoid output.</span> <span class="n">y</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="s2">"rating"</span><span class="p">]</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="p">(</span><span class="n">x</span> <span class="o">-</span> <span class="n">min_rating</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">max_rating</span> <span class="o">-</span> <span class="n">min_rating</span><span class="p">))</span><span class="o">.</span><span class="n">values</span> <span class="c1"># Train on the first 90% of the shuffled data and validate on the remaining 10%.</span> <span class="n">train_indices</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="mf">0.9</span> <span class="o">*</span> <span class="n">df</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="n">x_train</span><span class="p">,</span> <span class="n">x_val</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">y_val</span> <span class="o">=</span> <span class="p">(</span> <span class="n">x</span><span class="p">[:</span><span class="n">train_indices</span><span class="p">],</span> <span class="n">x</span><span class="p">[</span><span class="n">train_indices</span><span class="p">:],</span> <span class="n">y</span><span class="p">[:</span><span class="n">train_indices</span><span class="p">],</span> <span class="n">y</span><span
class="p">[</span><span class="n">train_indices</span><span class="p">:],</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="create-the-model">Create the model</h2> <p>We embed both users and movies into 50-dimensional vectors.</p> <p>The model computes a match score between the user and movie embeddings via a dot product, then adds a per-user and a per-movie bias. The sum is squashed into the <code>[0, 1]</code> interval by a sigmoid, since our ratings are normalized to this range.</p> <div class="codehilite"><pre><span></span><code><span class="n">EMBEDDING_SIZE</span> <span class="o">=</span> <span class="mi">50</span> <span class="k">class</span> <span class="nc">RecommenderNet</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_users</span><span class="p">,</span> <span class="n">num_movies</span><span class="p">,</span> <span class="n">embedding_size</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_users</span> <span class="o">=</span> <span class="n">num_users</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_movies</span> <span class="o">=</span> <span class="n">num_movies</span> <span class="bp">self</span><span class="o">.</span><span class="n">embedding_size</span> <span class="o">=</span> <span class="n">embedding_size</span> <span class="bp">self</span><span class="o">.</span><span class="n">user_embedding</span> <span
class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span> <span class="n">num_users</span><span class="p">,</span> <span class="n">embedding_size</span><span class="p">,</span> <span class="n">embeddings_initializer</span><span class="o">=</span><span class="s2">"he_normal"</span><span class="p">,</span> <span class="n">embeddings_regularizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">regularizers</span><span class="o">.</span><span class="n">l2</span><span class="p">(</span><span class="mf">1e-6</span><span class="p">),</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">user_bias</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">num_users</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">movie_embedding</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span> <span class="n">num_movies</span><span class="p">,</span> <span class="n">embedding_size</span><span class="p">,</span> <span class="n">embeddings_initializer</span><span class="o">=</span><span class="s2">"he_normal"</span><span class="p">,</span> <span class="n">embeddings_regularizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">regularizers</span><span class="o">.</span><span class="n">l2</span><span class="p">(</span><span class="mf">1e-6</span><span class="p">),</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">movie_bias</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Embedding</span><span 
class="p">(</span><span class="n">num_movies</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">user_vector</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">user_embedding</span><span class="p">(</span><span class="n">inputs</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span> <span class="n">user_bias</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">user_bias</span><span class="p">(</span><span class="n">inputs</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span> <span class="n">movie_vector</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">movie_embedding</span><span class="p">(</span><span class="n">inputs</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span> <span class="n">movie_bias</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">movie_bias</span><span class="p">(</span><span class="n">inputs</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">])</span> <span class="c1"># Dot product per example over the embedding axis, giving shape (batch_size, 1)</span> <span class="n">dot_user_movie</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">user_vector</span> <span class="o">*</span> <span class="n">movie_vector</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="c1"># Add the per-user and per-movie biases</span> <span class="n">x</span> <span class="o">=</span> <span class="n">dot_user_movie</span> <span class="o">+</span> <span class="n">user_bias</span> <span class="o">+</span>
<span class="n">movie_bias</span> <span class="c1"># The sigmoid squashes the prediction into the [0, 1] range of the normalized ratings</span> <span class="k">return</span> <span class="n">ops</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">RecommenderNet</span><span class="p">(</span><span class="n">num_users</span><span class="p">,</span> <span class="n">num_movies</span><span class="p">,</span> <span class="n">EMBEDDING_SIZE</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">BinaryCrossentropy</span><span class="p">(),</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.001</span><span class="p">),</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="train-the-model-based-on-the-data-split">Train the model based on the data split</h2> <div class="codehilite"><pre><span></span><code><span class="n">history</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">x</span><span class="o">=</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="n">y_train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span
class="n">epochs</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="p">(</span><span class="n">x_val</span><span class="p">,</span> <span class="n">y_val</span><span class="p">),</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/5 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - loss: 0.6591 - val_loss: 0.6201 Epoch 2/5 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 894us/step - loss: 0.6159 - val_loss: 0.6191 Epoch 3/5 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 977us/step - loss: 0.6093 - val_loss: 0.6138 Epoch 4/5 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 865us/step - loss: 0.6100 - val_loss: 0.6123 Epoch 5/5 1418/1418 ━━━━━━━━━━━━━━━━━━━━ 1s 854us/step - loss: 0.6072 - val_loss: 0.6121 </code></pre></div> </div> <hr /> <h2 id="plot-training-and-validation-loss">Plot training and validation loss</h2> <div class="codehilite"><pre><span></span><code><span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">"loss"</span><span class="p">])</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">"val_loss"</span><span class="p">])</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">"model loss"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s2">"loss"</span><span class="p">)</span> <span 
class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s2">"epoch"</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">([</span><span class="s2">"train"</span><span class="p">,</span> <span class="s2">"validation"</span><span class="p">],</span> <span class="n">loc</span><span class="o">=</span><span class="s2">"upper left"</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div> <p><img alt="Training and validation loss per epoch" src="/img/examples/structured_data/collaborative_filtering_movielens/collaborative_filtering_movielens_14_0.png" /></p> <hr /> <h2 id="show-top-10-movie-recommendations-to-a-user">Show top 10 movie recommendations to a user</h2> <div class="codehilite"><pre><span></span><code><span class="n">movie_df</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_csv</span><span class="p">(</span><span class="n">movielens_dir</span> <span class="o">/</span> <span class="s2">"movies.csv"</span><span class="p">)</span>

<span class="c1"># Pick a random user and look at their top recommendations.</span>
<span class="n">user_id</span> <span class="o">=</span> <span class="n">df</span><span class="o">.</span><span class="n">userId</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">iloc</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">movies_watched_by_user</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="n">df</span><span class="o">.</span><span class="n">userId</span> <span class="o">==</span> <span class="n">user_id</span><span class="p">]</span>
<span class="n">movies_not_watched</span> <span class="o">=</span> <span
class="n">movie_df</span><span class="p">[</span>
    <span class="o">~</span><span class="n">movie_df</span><span class="p">[</span><span class="s2">"movieId"</span><span class="p">]</span><span class="o">.</span><span class="n">isin</span><span class="p">(</span><span class="n">movies_watched_by_user</span><span class="o">.</span><span class="n">movieId</span><span class="o">.</span><span class="n">values</span><span class="p">)</span>
<span class="p">][</span><span class="s2">"movieId"</span><span class="p">]</span>
<span class="n">movies_not_watched</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span>
    <span class="nb">set</span><span class="p">(</span><span class="n">movies_not_watched</span><span class="p">)</span><span class="o">.</span><span class="n">intersection</span><span class="p">(</span><span class="nb">set</span><span class="p">(</span><span class="n">movie2movie_encoded</span><span class="o">.</span><span class="n">keys</span><span class="p">()))</span>
<span class="p">)</span>
<span class="n">movies_not_watched</span> <span class="o">=</span> <span class="p">[[</span><span class="n">movie2movie_encoded</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">x</span><span class="p">)]</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">movies_not_watched</span><span class="p">]</span>
<span class="n">user_encoder</span> <span class="o">=</span> <span class="n">user2user_encoded</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">user_id</span><span class="p">)</span>
<span class="n">user_movie_array</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">hstack</span><span class="p">(</span>
    <span class="p">([[</span><span class="n">user_encoder</span><span class="p">]]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">movies_not_watched</span><span class="p">),</span> <span class="n">movies_not_watched</span><span class="p">)</span>
<span class="p">)</span>
<span class="n">ratings</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">user_movie_array</span><span class="p">)</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span>
<span class="n">top_ratings_indices</span> <span class="o">=</span> <span class="n">ratings</span><span class="o">.</span><span class="n">argsort</span><span class="p">()[</span><span class="o">-</span><span class="mi">10</span><span class="p">:][::</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
<span class="n">recommended_movie_ids</span> <span class="o">=</span> <span class="p">[</span>
    <span class="n">movie_encoded2movie</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">movies_not_watched</span><span class="p">[</span><span class="n">x</span><span class="p">][</span><span class="mi">0</span><span class="p">])</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">top_ratings_indices</span>
<span class="p">]</span>

<span class="nb">print</span><span class="p">(</span><span class="s2">"Showing recommendations for user: </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">user_id</span><span class="p">))</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">"===="</span> <span class="o">*</span> <span class="mi">9</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">"Movies with high ratings from user"</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">"----"</span> <span class="o">*</span> <span class="mi">8</span><span class="p">)</span>
<span class="n">top_movies_user</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">movies_watched_by_user</span><span class="o">.</span><span class="n">sort_values</span><span class="p">(</span><span class="n">by</span><span class="o">=</span><span class="s2">"rating"</span><span class="p">,</span> <span class="n">ascending</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
    <span class="o">.</span><span class="n">head</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span>
    <span class="o">.</span><span class="n">movieId</span><span class="o">.</span><span class="n">values</span>
<span class="p">)</span>
<span class="n">movie_df_rows</span> <span class="o">=</span> <span class="n">movie_df</span><span class="p">[</span><span class="n">movie_df</span><span class="p">[</span><span class="s2">"movieId"</span><span class="p">]</span><span class="o">.</span><span class="n">isin</span><span class="p">(</span><span class="n">top_movies_user</span><span class="p">)]</span>
<span class="k">for</span> <span class="n">row</span> <span class="ow">in</span> <span class="n">movie_df_rows</span><span class="o">.</span><span class="n">itertuples</span><span class="p">():</span>
    <span class="nb">print</span><span class="p">(</span><span class="n">row</span><span class="o">.</span><span class="n">title</span><span class="p">,</span> <span class="s2">":"</span><span class="p">,</span> <span class="n">row</span><span class="o">.</span><span class="n">genres</span><span class="p">)</span>

<span class="nb">print</span><span class="p">(</span><span class="s2">"----"</span> <span class="o">*</span> <span class="mi">8</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">"Top 10 movie recommendations"</span><span class="p">)</span>
<span class="nb">print</span><span
class="p">(</span><span class="s2">"----"</span> <span class="o">*</span> <span class="mi">8</span><span class="p">)</span>
<span class="n">recommended_movies</span> <span class="o">=</span> <span class="n">movie_df</span><span class="p">[</span><span class="n">movie_df</span><span class="p">[</span><span class="s2">"movieId"</span><span class="p">]</span><span class="o">.</span><span class="n">isin</span><span class="p">(</span><span class="n">recommended_movie_ids</span><span class="p">)]</span>
<span class="k">for</span> <span class="n">row</span> <span class="ow">in</span> <span class="n">recommended_movies</span><span class="o">.</span><span class="n">itertuples</span><span class="p">():</span>
    <span class="nb">print</span><span class="p">(</span><span class="n">row</span><span class="o">.</span><span class="n">title</span><span class="p">,</span> <span class="s2">":"</span><span class="p">,</span> <span class="n">row</span><span class="o">.</span><span class="n">genres</span><span class="p">)</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 272/272 ━━━━━━━━━━━━━━━━━━━━ 0s 714us/step
Showing recommendations for user: 249
====================================
Movies with high ratings from user
--------------------------------
Fight Club (1999) : Action|Crime|Drama|Thriller
Serenity (2005) : Action|Adventure|Sci-Fi
Departed, The (2006) : Crime|Drama|Thriller
Prisoners (2013) : Drama|Mystery|Thriller
Arrival (2016) : Sci-Fi
--------------------------------
Top 10 movie recommendations
--------------------------------
In the Name of the Father (1993) : Drama
Monty Python and the Holy Grail (1975) : Adventure|Comedy|Fantasy
Princess Bride, The (1987) : Action|Adventure|Comedy|Fantasy|Romance
Lawrence of Arabia (1962) : Adventure|Drama|War
Apocalypse Now (1979) : Action|Drama|War
Full Metal Jacket (1987) : Drama|War
Amadeus (1984) : Drama
Glory (1989) : Drama|War
Chinatown (1974) : Crime|Film-Noir|Mystery|Thriller
City of God (Cidade de Deus) (2002) : Action|Adventure|Crime|Drama|Thriller
</code></pre></div> </div> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#collaborative-filtering-for-movie-recommendations'>Collaborative Filtering for Movie Recommendations</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#first-load-the-data-and-apply-preprocessing'>First, load the data and apply preprocessing</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#prepare-training-and-validation-data'>Prepare training and validation data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#create-the-model'>Create the model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-the-model-based-on-the-data-split'>Train the model based on the data split</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#plot-training-and-validation-loss'>Plot training and validation loss</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#show-top-10-movie-recommendations-to-a-user'>Show top 10 movie recommendations to a user</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>