<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/nlp/pretrained_word_embeddings/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Using pre-trained word embeddings"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Using pre-trained word embeddings"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Using pre-trained word embeddings</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); 
</script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink active" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_from_scratch/">Text classification from scratch</a> <a class="nav-sublink2" href="/examples/nlp/active_learning_review_classification/">Review Classification using Active Learning</a> <a class="nav-sublink2" href="/examples/nlp/fnet_classification_with_keras_hub/">Text Classification using FNet</a> <a class="nav-sublink2" href="/examples/nlp/multi_label_classification/">Large-scale multi-label text classification</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_with_transformer/">Text classification with Transformer</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_with_switch_transformer/">Text classification with Switch Transformer</a> <a class="nav-sublink2" href="/examples/nlp/tweet-classification-using-tfdf/">Text classification using Decision Forests and pretrained embeddings</a> <a 
class="nav-sublink2 active" href="/examples/nlp/pretrained_word_embeddings/">Using pre-trained word embeddings</a> <a class="nav-sublink2" href="/examples/nlp/bidirectional_lstm_imdb/">Bidirectional LSTM on IMDB</a> <a class="nav-sublink2" href="/examples/nlp/data_parallel_training_with_keras_hub/">Data Parallel Training with KerasHub and tf.distribute</a> <a class="nav-sublink2" href="/examples/nlp/neural_machine_translation_with_keras_hub/">English-to-Spanish translation with KerasHub</a> <a class="nav-sublink2" href="/examples/nlp/neural_machine_translation_with_transformer/">English-to-Spanish translation with a sequence-to-sequence Transformer</a> <a class="nav-sublink2" href="/examples/nlp/lstm_seq2seq/">Character-level recurrent sequence-to-sequence model</a> <a class="nav-sublink2" href="/examples/nlp/multimodal_entailment/">Multimodal entailment</a> <a class="nav-sublink2" href="/examples/nlp/ner_transformers/">Named Entity Recognition using Transformers</a> <a class="nav-sublink2" href="/examples/nlp/text_extraction_with_bert/">Text Extraction with BERT</a> <a class="nav-sublink2" href="/examples/nlp/addition_rnn/">Sequence to sequence learning for performing number addition</a> <a class="nav-sublink2" href="/examples/nlp/semantic_similarity_with_keras_hub/">Semantic Similarity with KerasHub</a> <a class="nav-sublink2" href="/examples/nlp/semantic_similarity_with_bert/">Semantic Similarity with BERT</a> <a class="nav-sublink2" href="/examples/nlp/sentence_embeddings_with_sbert/">Sentence embeddings using Siamese RoBERTa-networks</a> <a class="nav-sublink2" href="/examples/nlp/masked_language_modeling/">End-to-end Masked Language Modeling with BERT</a> <a class="nav-sublink2" href="/examples/nlp/abstractive_summarization_with_bart/">Abstractive Text Summarization with BART</a> <a class="nav-sublink2" href="/examples/nlp/pretraining_BERT/">Pretraining BERT with Hugging Face Transformers</a> <a class="nav-sublink2" 
href="/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/">Parameter-efficient fine-tuning of GPT-2 with LoRA</a> <a class="nav-sublink2" href="/examples/nlp/mlm_training_tpus/">Training a language model from scratch with 🤗 Transformers and TPUs</a> <a class="nav-sublink2" href="/examples/nlp/multiple_choice_task_with_transfer_learning/">MultipleChoice Task with Transfer Learning</a> <a class="nav-sublink2" href="/examples/nlp/question_answering/">Question Answering with Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/nlp/t5_hf_summarization/">Abstractive Summarization with Hugging Face Transformers</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; 
document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = 
document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/nlp/'>Natural Language Processing</a> / Using pre-trained word embeddings </div> <div class='k-content'> <h1 id="using-pretrained-word-embeddings">Using pre-trained word embeddings</h1> <p><strong>Author:</strong> <a href="https://twitter.com/fchollet">fchollet</a><br> <strong>Date created:</strong> 2020/05/05<br> <strong>Last modified:</strong> 2020/05/05<br> <strong>Description:</strong> Text classification on the Newsgroup20 dataset using pre-trained GloVe word embeddings.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/nlp/ipynb/pretrained_word_embeddings.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/nlp/pretrained_word_embeddings.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span> <span class="c1"># Only the TensorFlow backend supports string inputs.</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"tensorflow"</span> <span class="kn">import</span> <span class="nn">pathlib</span> <span
class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">tensorflow.data</span> <span class="k">as</span> <span class="nn">tf_data</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> </code></pre></div> <hr /> <h2 id="introduction">Introduction</h2> <p>In this example, we show how to train a text classification model that uses pre-trained word embeddings.</p> <p>We'll work with the Newsgroup20 dataset, a set of 20,000 message board messages belonging to 20 different topic categories.</p> <p>For the pre-trained word embeddings, we'll use <a href="http://nlp.stanford.edu/projects/glove/">GloVe embeddings</a>.</p> <hr /> <h2 id="download-the-newsgroup20-data">Download the Newsgroup20 data</h2> <div class="codehilite"><pre><span></span><code><span class="n">data_path</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">get_file</span><span class="p">(</span> <span class="s2">"news20.tar.gz"</span><span class="p">,</span> <span class="s2">"http://www.cs.cmu.edu/afs/cs.cmu.edu/project/theo-20/www/data/news20.tar.gz"</span><span class="p">,</span> <span class="n">untar</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="lets-take-a-look-at-the-data">Let's take a look at the data</h2> <div class="codehilite"><pre><span></span><code><span class="n">data_dir</span> <span class="o">=</span> <span class="n">pathlib</span><span class="o">.</span><span class="n">Path</span><span class="p">(</span><span class="n">data_path</span><span class="p">)</span><span class="o">.</span><span class="n">parent</span> <span class="o">/</span> <span 
class="s2">"20_newsgroup"</span> <span class="n">dirnames</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">listdir</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Number of directories:"</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">dirnames</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Directory names:"</span><span class="p">,</span> <span class="n">dirnames</span><span class="p">)</span> <span class="n">fnames</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">listdir</span><span class="p">(</span><span class="n">data_dir</span> <span class="o">/</span> <span class="s2">"comp.graphics"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Number of files in comp.graphics:"</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">fnames</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Some example filenames:"</span><span class="p">,</span> <span class="n">fnames</span><span class="p">[:</span><span class="mi">5</span><span class="p">])</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Number of directories: 20 Directory names: ['comp.sys.ibm.pc.hardware', 'comp.os.ms-windows.misc', 'comp.windows.x', 'sci.space', 'sci.crypt', 'sci.med', 'alt.atheism', 'rec.autos', 'rec.sport.hockey', 'talk.politics.misc', 'talk.politics.mideast', 'rec.motorcycles', 'talk.politics.guns', 'misc.forsale', 'sci.electronics', 'talk.religion.misc', 'comp.graphics', 'soc.religion.christian', 'comp.sys.mac.hardware', 'rec.sport.baseball'] Number of files in comp.graphics: 1000 Some 
example filenames: ['39638', '38747', '38242', '39057', '39031'] </code></pre></div> </div> <p>Here's an example of what one file contains:</p> <div class="codehilite"><pre><span></span><code><span class="nb">print</span><span class="p">(</span><span class="nb">open</span><span class="p">(</span><span class="n">data_dir</span> <span class="o">/</span> <span class="s2">"comp.graphics"</span> <span class="o">/</span> <span class="s2">"38987"</span><span class="p">)</span><span class="o">.</span><span class="n">read</span><span class="p">())</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Newsgroups: comp.graphics Path: cantaloupe.srv.cs.cmu.edu!das-news.harvard.edu!noc.near.net!howland.reston.ans.net!agate!dog.ee.lbl.gov!network.ucsd.edu!usc!rpi!nason110.its.rpi.edu!mabusj From: mabusj@nason110.its.rpi.edu (Jasen M. Mabus) Subject: Looking for Brain in CAD Message-ID: <c285m+p@rpi.edu> Nntp-Posting-Host: nason110.its.rpi.edu Reply-To: mabusj@rpi.edu Organization: Rensselaer Polytechnic Institute, Troy, NY. Date: Thu, 29 Apr 1993 23:27:20 GMT Lines: 7 </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Jasen Mabus RPI student </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> I am looking for a hman brain in any CAD (.dxf,.cad,.iges,.cgm,etc.) or picture (.gif,.jpg,.ras,etc.) format for an animation demonstration. If any has or knows of a location please reply by e-mail to mabusj@rpi.edu. </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Thank you in advance, Jasen Mabus </code></pre></div> </div> <p>As you can see, there are header lines that are leaking the file's category, either explicitly (the first line is literally the category name), or implicitly, e.g. via the <code>Organization</code> field.
Let's get rid of the headers:</p> <div class="codehilite"><pre><span></span><code><span class="n">samples</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">labels</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">class_names</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">class_index</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">dirname</span> <span class="ow">in</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">listdir</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)):</span> <span class="n">class_names</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">dirname</span><span class="p">)</span> <span class="n">dirpath</span> <span class="o">=</span> <span class="n">data_dir</span> <span class="o">/</span> <span class="n">dirname</span> <span class="n">fnames</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">listdir</span><span class="p">(</span><span class="n">dirpath</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Processing </span><span class="si">%s</span><span class="s2">, </span><span class="si">%d</span><span class="s2"> files found"</span> <span class="o">%</span> <span class="p">(</span><span class="n">dirname</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">fnames</span><span class="p">)))</span> <span class="k">for</span> <span class="n">fname</span> <span class="ow">in</span> <span class="n">fnames</span><span class="p">:</span> <span class="n">fpath</span> <span class="o">=</span> <span class="n">dirpath</span> <span class="o">/</span> <span class="n">fname</span> <span class="n">f</span> <span class="o">=</span> <span 
class="nb">open</span><span class="p">(</span><span class="n">fpath</span><span class="p">,</span> <span class="n">encoding</span><span class="o">=</span><span class="s2">"latin-1"</span><span class="p">)</span> <span class="n">content</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">()</span> <span class="n">lines</span> <span class="o">=</span> <span class="n">content</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"</span><span class="se">\n</span><span class="s2">"</span><span class="p">)</span> <span class="n">lines</span> <span class="o">=</span> <span class="n">lines</span><span class="p">[</span><span class="mi">10</span><span class="p">:]</span> <span class="n">content</span> <span class="o">=</span> <span class="s2">"</span><span class="se">\n</span><span class="s2">"</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">lines</span><span class="p">)</span> <span class="n">samples</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">content</span><span class="p">)</span> <span class="n">labels</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">class_index</span><span class="p">)</span> <span class="n">class_index</span> <span class="o">+=</span> <span class="mi">1</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Classes:"</span><span class="p">,</span> <span class="n">class_names</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Number of samples:"</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">samples</span><span class="p">))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Processing alt.atheism, 1000 files 
found Processing comp.graphics, 1000 files found Processing comp.os.ms-windows.misc, 1000 files found Processing comp.sys.ibm.pc.hardware, 1000 files found Processing comp.sys.mac.hardware, 1000 files found Processing comp.windows.x, 1000 files found Processing misc.forsale, 1000 files found Processing rec.autos, 1000 files found Processing rec.motorcycles, 1000 files found Processing rec.sport.baseball, 1000 files found Processing rec.sport.hockey, 1000 files found Processing sci.crypt, 1000 files found Processing sci.electronics, 1000 files found Processing sci.med, 1000 files found Processing sci.space, 1000 files found Processing soc.religion.christian, 997 files found Processing talk.politics.guns, 1000 files found Processing talk.politics.mideast, 1000 files found Processing talk.politics.misc, 1000 files found Processing talk.religion.misc, 1000 files found Classes: ['alt.atheism', 'comp.graphics', 'comp.os.ms-windows.misc', 'comp.sys.ibm.pc.hardware', 'comp.sys.mac.hardware', 'comp.windows.x', 'misc.forsale', 'rec.autos', 'rec.motorcycles', 'rec.sport.baseball', 'rec.sport.hockey', 'sci.crypt', 'sci.electronics', 'sci.med', 'sci.space', 'soc.religion.christian', 'talk.politics.guns', 'talk.politics.mideast', 'talk.politics.misc', 'talk.religion.misc'] Number of samples: 19997 </code></pre></div> </div> <p>There's actually one category that doesn't have the expected number of files, but the difference is small enough that the problem remains a balanced classification problem.</p> <hr /> <h2 id="shuffle-and-split-the-data-into-training-amp-validation-sets">Shuffle and split the data into training & validation sets</h2> <div class="codehilite"><pre><span></span><code><span class="c1"># Shuffle the data</span> <span class="n">seed</span> <span class="o">=</span> <span class="mi">1337</span> <span class="n">rng</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span 
class="n">RandomState</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span> <span class="n">rng</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">samples</span><span class="p">)</span> <span class="n">rng</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">RandomState</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span> <span class="n">rng</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span> <span class="c1"># Extract a training & validation split</span> <span class="n">validation_split</span> <span class="o">=</span> <span class="mf">0.2</span> <span class="n">num_validation_samples</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">validation_split</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">samples</span><span class="p">))</span> <span class="n">train_samples</span> <span class="o">=</span> <span class="n">samples</span><span class="p">[:</span><span class="o">-</span><span class="n">num_validation_samples</span><span class="p">]</span> <span class="n">val_samples</span> <span class="o">=</span> <span class="n">samples</span><span class="p">[</span><span class="o">-</span><span class="n">num_validation_samples</span><span class="p">:]</span> <span class="n">train_labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">[:</span><span class="o">-</span><span class="n">num_validation_samples</span><span class="p">]</span> <span class="n">val_labels</span> <span class="o">=</span> <span class="n">labels</span><span class="p">[</span><span class="o">-</span><span class="n">num_validation_samples</span><span class="p">:]</span> </code></pre></div> <hr /> 
<h2 id="create-a-vocabulary-index">Create a vocabulary index</h2> <p>Let's use the <code>TextVectorization</code> to index the vocabulary found in the dataset. Later, we'll use the same layer instance to vectorize the samples.</p> <p>Our layer will only consider the top 20,000 words, and will truncate or pad sequences to be actually 200 tokens long.</p> <div class="codehilite"><pre><span></span><code><span class="n">vectorizer</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">TextVectorization</span><span class="p">(</span><span class="n">max_tokens</span><span class="o">=</span><span class="mi">20000</span><span class="p">,</span> <span class="n">output_sequence_length</span><span class="o">=</span><span class="mi">200</span><span class="p">)</span> <span class="n">text_ds</span> <span class="o">=</span> <span class="n">tf_data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">(</span><span class="n">train_samples</span><span class="p">)</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="mi">128</span><span class="p">)</span> <span class="n">vectorizer</span><span class="o">.</span><span class="n">adapt</span><span class="p">(</span><span class="n">text_ds</span><span class="p">)</span> </code></pre></div> <p>You can retrieve the computed vocabulary used via <code>vectorizer.get_vocabulary()</code>. 
Let's print the top 5 words:</p> <div class="codehilite"><pre><span></span><code><span class="n">vectorizer</span><span class="o">.</span><span class="n">get_vocabulary</span><span class="p">()[:</span><span class="mi">5</span><span class="p">]</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>['', '[UNK]', 'the', 'to', 'of'] </code></pre></div> </div> <p>Let's vectorize a test sentence:</p> <div class="codehilite"><pre><span></span><code><span class="n">output</span> <span class="o">=</span> <span class="n">vectorizer</span><span class="p">([[</span><span class="s2">"the cat sat on the mat"</span><span class="p">]])</span> <span class="n">output</span><span class="o">.</span><span class="n">numpy</span><span class="p">()[</span><span class="mi">0</span><span class="p">,</span> <span class="p">:</span><span class="mi">6</span><span class="p">]</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>array([ 2, 3480, 1818, 15, 2, 5830]) </code></pre></div> </div> <p>As you can see, "the" gets represented as "2". Why not 0, given that "the" was the first word in the vocabulary? 
That's because index 0 is reserved for padding and index 1 is reserved for "out of vocabulary" tokens.</p> <p>Here's a dict mapping words to their indices:</p> <div class="codehilite"><pre><span></span><code><span class="n">voc</span> <span class="o">=</span> <span class="n">vectorizer</span><span class="o">.</span><span class="n">get_vocabulary</span><span class="p">()</span> <span class="n">word_index</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">voc</span><span class="p">,</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">voc</span><span class="p">))))</span> </code></pre></div> <p>As you can see, we obtain the same encoding as above for our test sentence:</p> <div class="codehilite"><pre><span></span><code><span class="n">test</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"the"</span><span class="p">,</span> <span class="s2">"cat"</span><span class="p">,</span> <span class="s2">"sat"</span><span class="p">,</span> <span class="s2">"on"</span><span class="p">,</span> <span class="s2">"the"</span><span class="p">,</span> <span class="s2">"mat"</span><span class="p">]</span> <span class="p">[</span><span class="n">word_index</span><span class="p">[</span><span class="n">w</span><span class="p">]</span> <span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="n">test</span><span class="p">]</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>[2, 3480, 1818, 15, 2, 5830] </code></pre></div> </div> <hr /> <h2 id="load-pretrained-word-embeddings">Load pre-trained word embeddings</h2> <p>Let's download pre-trained GloVe embeddings (an 822M zip file).</p> <p>You'll need to run the following commands:</p> <div class="codehilite"><pre><span></span><code><span
class="err">!</span><span class="n">wget</span> <span class="n">https</span><span class="p">:</span><span class="o">//</span><span class="n">downloads</span><span class="o">.</span><span class="n">cs</span><span class="o">.</span><span class="n">stanford</span><span class="o">.</span><span class="n">edu</span><span class="o">/</span><span class="n">nlp</span><span class="o">/</span><span class="n">data</span><span class="o">/</span><span class="n">glove</span><span class="mf">.6</span><span class="n">B</span><span class="o">.</span><span class="n">zip</span> <span class="err">!</span><span class="n">unzip</span> <span class="o">-</span><span class="n">q</span> <span class="n">glove</span><span class="mf">.6</span><span class="n">B</span><span class="o">.</span><span class="n">zip</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>--2023-11-19 22:45:27-- https://downloads.cs.stanford.edu/nlp/data/glove.6B.zip Resolving downloads.cs.stanford.edu (downloads.cs.stanford.edu)... 171.64.64.22 Connecting to downloads.cs.stanford.edu (downloads.cs.stanford.edu)|171.64.64.22|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 862182613 (822M) [application/zip] Saving to: ‘glove.6B.zip’ </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>glove.6B.zip 100%[===================>] 822.24M 5.05MB/s in 2m 39s </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>2023-11-19 22:48:06 (5.19 MB/s) - ‘glove.6B.zip’ saved [862182613/862182613] </code></pre></div> </div> <p>The archive contains text-encoded vectors of various sizes: 50-dimensional, 100-dimensional, 200-dimensional, 300-dimensional. 
We'll use the 100D ones.</p> <p>Let's make a dict mapping words (strings) to their NumPy vector representation:</p> <div class="codehilite"><pre><span></span><code><span class="n">path_to_glove_file</span> <span class="o">=</span> <span class="s2">"glove.6B.100d.txt"</span> <span class="n">embeddings_index</span> <span class="o">=</span> <span class="p">{}</span> <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">path_to_glove_file</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span> <span class="k">for</span> <span class="n">line</span> <span class="ow">in</span> <span class="n">f</span><span class="p">:</span> <span class="n">word</span><span class="p">,</span> <span class="n">coefs</span> <span class="o">=</span> <span class="n">line</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">maxsplit</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="n">coefs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">fromstring</span><span class="p">(</span><span class="n">coefs</span><span class="p">,</span> <span class="s2">"f"</span><span class="p">,</span> <span class="n">sep</span><span class="o">=</span><span class="s2">" "</span><span class="p">)</span> <span class="n">embeddings_index</span><span class="p">[</span><span class="n">word</span><span class="p">]</span> <span class="o">=</span> <span class="n">coefs</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Found </span><span class="si">%s</span><span class="s2"> word vectors."</span> <span class="o">%</span> <span class="nb">len</span><span class="p">(</span><span class="n">embeddings_index</span><span class="p">))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Found 400000 word vectors. 
</code></pre></div> </div> <p>Now, let's prepare a corresponding embedding matrix that we can use in a Keras <code>Embedding</code> layer. It's a simple NumPy matrix where entry at index <code>i</code> is the pre-trained vector for the word of index <code>i</code> in our <code>vectorizer</code>'s vocabulary.</p> <div class="codehilite"><pre><span></span><code><span class="n">num_tokens</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">voc</span><span class="p">)</span> <span class="o">+</span> <span class="mi">2</span> <span class="n">embedding_dim</span> <span class="o">=</span> <span class="mi">100</span> <span class="n">hits</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">misses</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># Prepare embedding matrix</span> <span class="n">embedding_matrix</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">num_tokens</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">))</span> <span class="k">for</span> <span class="n">word</span><span class="p">,</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">word_index</span><span class="o">.</span><span class="n">items</span><span class="p">():</span> <span class="n">embedding_vector</span> <span class="o">=</span> <span class="n">embeddings_index</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">word</span><span class="p">)</span> <span class="k">if</span> <span class="n">embedding_vector</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="c1"># Words not found in embedding index will be all-zeros.</span> <span class="c1"># This includes the representation for "padding" and "OOV"</span> <span 
class="n">embedding_matrix</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">embedding_vector</span> <span class="n">hits</span> <span class="o">+=</span> <span class="mi">1</span> <span class="k">else</span><span class="p">:</span> <span class="n">misses</span> <span class="o">+=</span> <span class="mi">1</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Converted </span><span class="si">%d</span><span class="s2"> words (</span><span class="si">%d</span><span class="s2"> misses)"</span> <span class="o">%</span> <span class="p">(</span><span class="n">hits</span><span class="p">,</span> <span class="n">misses</span><span class="p">))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Converted 18021 words (1979 misses) </code></pre></div> </div> <p>Next, we load the pre-trained word embeddings matrix into an <code>Embedding</code> layer.</p> <p>Note that we set <code>trainable=False</code> so as to keep the embeddings fixed (we don't want to update them during training).</p> <div class="codehilite"><pre><span></span><code><span class="kn">from</span> <span class="nn">keras.layers</span> <span class="kn">import</span> <span class="n">Embedding</span> <span class="n">embedding_layer</span> <span class="o">=</span> <span class="n">Embedding</span><span class="p">(</span> <span class="n">num_tokens</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="p">)</span> <span class="n">embedding_layer</span><span class="o">.</span><span class="n">build</span><span class="p">((</span><span class="mi">1</span><span class="p">,))</span> <span class="n">embedding_layer</span><span class="o">.</span><span class="n">set_weights</span><span class="p">([</span><span 
class="n">embedding_matrix</span><span class="p">])</span> </code></pre></div> <hr /> <h2 id="build-the-model">Build the model</h2> <p>A simple 1D convnet with global max pooling and a classifier at the end.</p> <div class="codehilite"><pre><span></span><code><span class="n">int_sequences_input</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int32"</span><span class="p">)</span> <span class="n">embedded_sequences</span> <span class="o">=</span> <span class="n">embedding_layer</span><span class="p">(</span><span class="n">int_sequences_input</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv1D</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">embedded_sequences</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MaxPooling1D</span><span class="p">(</span><span class="mi">5</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv1D</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span 
class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MaxPooling1D</span><span class="p">(</span><span class="mi">5</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv1D</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">GlobalMaxPooling1D</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">preds</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">class_names</span><span class="p">),</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"softmax"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">model</span> <span 
class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">int_sequences_input</span><span class="p">,</span> <span class="n">preds</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> </code></pre></div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "functional_1"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ input_layer (<span style="color: #0087ff; text-decoration-color: #0087ff">InputLayer</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ embedding (<span style="color: #0087ff; text-decoration-color: #0087ff">Embedding</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">100</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">2,000,200</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv1d (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv1D</span>) │ (<span 
style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">64,128</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ max_pooling1d (<span style="color: #0087ff; text-decoration-color: #0087ff">MaxPooling1D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv1d_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv1D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">82,048</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ max_pooling1d_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">MaxPooling1D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv1d_2 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv1D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; 
text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">82,048</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ global_max_pooling1d │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">GlobalMaxPooling1D</span>) │ │ │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dense (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">16,512</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dropout (<span style="color: #0087ff; text-decoration-color: #0087ff">Dropout</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dense_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">20</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">2,580</span> │ └─────────────────────────────────┴───────────────────────────┴────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier 
New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">2,247,516</span> (8.57 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">2,247,516</span> (8.57 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <hr /> <h2 id="train-the-model">Train the model</h2> <p>First, convert our list-of-strings data to NumPy arrays of integer indices. The arrays are right-padded.</p> <div class="codehilite"><pre><span></span><code><span class="n">x_train</span> <span class="o">=</span> <span class="n">vectorizer</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="n">s</span><span class="p">]</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">train_samples</span><span class="p">]))</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="n">x_val</span> <span class="o">=</span> <span class="n">vectorizer</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="n">s</span><span class="p">]</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">val_samples</span><span class="p">]))</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="n">y_train</span> <span class="o">=</span> <span class="n">np</span><span 
class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">train_labels</span><span class="p">)</span> <span class="n">y_val</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">val_labels</span><span class="p">)</span> </code></pre></div> <p>Since we're doing softmax classification, we use categorical crossentropy as our loss; specifically, the <code>sparse_categorical_crossentropy</code> variant, since our labels are integer indices rather than one-hot vectors.</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"sparse_categorical_crossentropy"</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="s2">"rmsprop"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"acc"</span><span class="p">]</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="p">(</span><span class="n">x_val</span><span class="p">,</span> <span class="n">y_val</span><span class="p">))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/20 2/125 ━━━━━━━━━━━━━━━━━━━━ 9s 78ms/step - acc: 0.0352 - loss: 3.2164 WARNING: All log messages before absl::InitializeLog() is called are written to STDERR 
I0000 00:00:1700434131.619687 6780 device_compiler.h:187] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. 125/125 ━━━━━━━━━━━━━━━━━━━━ 22s 123ms/step - acc: 0.0926 - loss: 2.8961 - val_acc: 0.2451 - val_loss: 2.1965 Epoch 2/20 125/125 ━━━━━━━━━━━━━━━━━━━━ 10s 78ms/step - acc: 0.2628 - loss: 2.1377 - val_acc: 0.4421 - val_loss: 1.6594 Epoch 3/20 125/125 ━━━━━━━━━━━━━━━━━━━━ 10s 78ms/step - acc: 0.4504 - loss: 1.5765 - val_acc: 0.5849 - val_loss: 1.2577 Epoch 4/20 125/125 ━━━━━━━━━━━━━━━━━━━━ 10s 76ms/step - acc: 0.5711 - loss: 1.2639 - val_acc: 0.6277 - val_loss: 1.1153 Epoch 5/20 125/125 ━━━━━━━━━━━━━━━━━━━━ 9s 74ms/step - acc: 0.6430 - loss: 1.0318 - val_acc: 0.6684 - val_loss: 0.9902 Epoch 6/20 125/125 ━━━━━━━━━━━━━━━━━━━━ 9s 72ms/step - acc: 0.6990 - loss: 0.8844 - val_acc: 0.6619 - val_loss: 1.0109 Epoch 7/20 125/125 ━━━━━━━━━━━━━━━━━━━━ 9s 70ms/step - acc: 0.7330 - loss: 0.7614 - val_acc: 0.6832 - val_loss: 0.9585 Epoch 8/20 125/125 ━━━━━━━━━━━━━━━━━━━━ 8s 68ms/step - acc: 0.7795 - loss: 0.6328 - val_acc: 0.6847 - val_loss: 0.9917 Epoch 9/20 125/125 ━━━━━━━━━━━━━━━━━━━━ 8s 64ms/step - acc: 0.8203 - loss: 0.5242 - val_acc: 0.7187 - val_loss: 0.9224 Epoch 10/20 125/125 ━━━━━━━━━━━━━━━━━━━━ 8s 60ms/step - acc: 0.8506 - loss: 0.4265 - val_acc: 0.7342 - val_loss: 0.9098 Epoch 11/20 125/125 ━━━━━━━━━━━━━━━━━━━━ 7s 56ms/step - acc: 0.8756 - loss: 0.3659 - val_acc: 0.7204 - val_loss: 1.0022 Epoch 12/20 125/125 ━━━━━━━━━━━━━━━━━━━━ 7s 54ms/step - acc: 0.8921 - loss: 0.3079 - val_acc: 0.7209 - val_loss: 1.0477 Epoch 13/20 125/125 ━━━━━━━━━━━━━━━━━━━━ 7s 54ms/step - acc: 0.9077 - loss: 0.2767 - val_acc: 0.7169 - val_loss: 1.0915 Epoch 14/20 125/125 ━━━━━━━━━━━━━━━━━━━━ 6s 50ms/step - acc: 0.9244 - loss: 0.2253 - val_acc: 0.7382 - val_loss: 1.1397 Epoch 15/20 125/125 ━━━━━━━━━━━━━━━━━━━━ 6s 49ms/step - acc: 0.9301 - loss: 0.2054 - val_acc: 0.7562 - val_loss: 1.0984 Epoch 16/20 125/125 ━━━━━━━━━━━━━━━━━━━━ 5s 42ms/step - acc: 
0.9373 - loss: 0.1769 - val_acc: 0.7387 - val_loss: 1.2294 Epoch 17/20 125/125 ━━━━━━━━━━━━━━━━━━━━ 5s 41ms/step - acc: 0.9467 - loss: 0.1626 - val_acc: 0.7009 - val_loss: 1.4906 Epoch 18/20 125/125 ━━━━━━━━━━━━━━━━━━━━ 5s 39ms/step - acc: 0.9471 - loss: 0.1544 - val_acc: 0.7184 - val_loss: 1.6050 Epoch 19/20 125/125 ━━━━━━━━━━━━━━━━━━━━ 5s 37ms/step - acc: 0.9532 - loss: 0.1388 - val_acc: 0.7407 - val_loss: 1.4360 Epoch 20/20 125/125 ━━━━━━━━━━━━━━━━━━━━ 5s 37ms/step - acc: 0.9519 - loss: 0.1388 - val_acc: 0.7309 - val_loss: 1.5327 <keras.src.callbacks.history.History at 0x7fbf50e6b910> </code></pre></div> </div> <hr /> <h2 id="export-an-endtoend-model">Export an end-to-end model</h2> <p>Now, we may want to export a <code>Model</code> object that takes as input a string of arbitrary length, rather than a sequence of indices. It would make the model much more portable, since you wouldn't have to worry about the input preprocessing pipeline.</p> <p>Our <code>vectorizer</code> is actually a Keras layer, so it's simple:</p> <div class="codehilite"><pre><span></span><code><span class="n">string_input</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"string"</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">vectorizer</span><span class="p">(</span><span class="n">string_input</span><span class="p">)</span> <span class="n">preds</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">end_to_end_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span 
class="n">string_input</span><span class="p">,</span> <span class="n">preds</span><span class="p">)</span> <span class="n">probabilities</span> <span class="o">=</span> <span class="n">end_to_end_model</span><span class="p">(</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">convert_to_tensor</span><span class="p">(</span> <span class="p">[[</span><span class="s2">"this message is about computer graphics and 3D modeling"</span><span class="p">]]</span> <span class="p">)</span> <span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">class_names</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">probabilities</span><span class="p">[</span><span class="mi">0</span><span class="p">])])</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>comp.graphics </code></pre></div> </div> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#using-pretrained-word-embeddings'>Using pre-trained word embeddings</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#download-the-newsgroup20-data'>Download the Newsgroup20 data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#lets-take-a-look-at-the-data'>Let's take a look at the data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#shuffle-and-split-the-data-into-training-amp-validation-sets'>Shuffle and split the data into training & validation sets</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#create-a-vocabulary-index'>Create a vocabulary index</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#load-pretrained-word-embeddings'>Load pre-trained word embeddings</a> </div> <div 
class='k-outline-depth-2'> ◆ <a href='#build-the-model'>Build the model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-the-model'>Train the model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#export-an-endtoend-model'>Export an end-to-end model</a> </div> </div> </div> </div> </div> </body> </html>