autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/nlp/'>Natural Language Processing</a> / English-to-Spanish translation with KerasHub </div> <div class='k-content'> <h1 id="englishtospanish-translation-with-kerashub">English-to-Spanish translation with KerasHub</h1> <p><strong>Author:</strong> <a href="https://github.com/abheesht17/">Abheesht Sharma</a><br> <strong>Date created:</strong> 2022/05/26<br> <strong>Last modified:</strong> 2024/04/30<br> <strong>Description:</strong> Use KerasHub to train a sequence-to-sequence Transformer model on the machine translation task.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/nlp/ipynb/neural_machine_translation_with_keras_hub.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/nlp/neural_machine_translation_with_keras_hub.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>KerasHub provides building blocks for NLP (model layers, tokenizers, metrics, etc.) and makes it convenient to construct NLP pipelines.</p> <p>In this example, we'll use KerasHub layers to build an encoder-decoder Transformer model, and train it on the English-to-Spanish machine translation task.</p> <p>This example is based on the <a href="https://keras.io/examples/nlp/neural_machine_translation_with_transformer/">English-to-Spanish NMT example</a> by <a href="https://twitter.com/fchollet">fchollet</a>. 
The original example is more low-level and implements layers from scratch, whereas this example uses KerasHub to show some more advanced approaches, such as subword tokenization and using metrics to compute the quality of generated translations.

You'll learn how to:

- Tokenize text using [`keras_hub.tokenizers.WordPieceTokenizer`](/keras_hub/api/tokenizers/word_piece_tokenizer#wordpiecetokenizer-class).
- Implement a sequence-to-sequence Transformer model using KerasHub's [`keras_hub.layers.TransformerEncoder`](/keras_hub/api/modeling_layers/transformer_encoder#transformerencoder-class), [`keras_hub.layers.TransformerDecoder`](/keras_hub/api/modeling_layers/transformer_decoder#transformerdecoder-class) and [`keras_hub.layers.TokenAndPositionEmbedding`](/keras_hub/api/modeling_layers/token_and_position_embedding#tokenandpositionembedding-class) layers, and train it.
- Use `keras_hub.samplers` to generate translations of unseen input sentences with a greedy decoding strategy!

Don't worry if you aren't familiar with KerasHub. This tutorial will start with the basics. Let's dive right in!

---

## Setup

Before we start implementing the pipeline, let's import all the libraries we need.

```python
!pip install -q --upgrade rouge-score
!pip install -q --upgrade keras-hub
!pip install -q --upgrade keras  # Upgrade to Keras 3.
```

```python
import keras_hub
import pathlib
import random

import keras
from keras import ops

import tensorflow.data as tf_data
from tensorflow_text.tools.wordpiece_vocab import (
    bert_vocab_from_dataset as bert_vocab,
)
```

```
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tensorflow 2.15.1 requires keras<2.16,>=2.15.0, but you have keras 3.3.3 which is incompatible.
```

Let's also define our parameters/hyperparameters.

```python
BATCH_SIZE = 64
EPOCHS = 1  # This should be at least 10 for convergence
MAX_SEQUENCE_LENGTH = 40
ENG_VOCAB_SIZE = 15000
SPA_VOCAB_SIZE = 15000

EMBED_DIM = 256
INTERMEDIATE_DIM = 2048
NUM_HEADS = 8
```

---

## Downloading the data

We'll be working with an English-to-Spanish translation dataset provided by [Anki](https://www.manythings.org/anki/). Let's download it:

```python
text_file = keras.utils.get_file(
    fname="spa-eng.zip",
    origin="http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip",
    extract=True,
)
text_file = pathlib.Path(text_file).parent / "spa-eng" / "spa.txt"
```

```
Downloading data from http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip
 2638744/2638744 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
```

---

## Parsing the data

Each line contains an English sentence and its corresponding Spanish sentence. The English sentence is the *source sequence* and the Spanish one is the *target sequence*.
Before adding the text to a list, we convert it to lowercase.

```python
with open(text_file) as f:
    lines = f.read().split("\n")[:-1]
text_pairs = []
for line in lines:
    eng, spa = line.split("\t")
    eng = eng.lower()
    spa = spa.lower()
    text_pairs.append((eng, spa))
```

Here's what our sentence pairs look like:

```python
for _ in range(5):
    print(random.choice(text_pairs))
```

```
('tom heard that mary had bought a new computer.', 'tom oyó que mary se había comprado un computador nuevo.')
('will you stay at home?', '¿te vas a quedar en casa?')
('where is this train going?', '¿adónde va este tren?')
('tom panicked.', 'tom entró en pánico.')
("we'll help you rescue tom.", 'te ayudaremos a rescatar a tom.')
```

Now, let's split the sentence pairs into a training set, a validation set, and a test set.

```python
random.shuffle(text_pairs)
num_val_samples = int(0.15 * len(text_pairs))
num_train_samples = len(text_pairs) - 2 * num_val_samples
train_pairs = text_pairs[:num_train_samples]
val_pairs = text_pairs[num_train_samples : num_train_samples + num_val_samples]
test_pairs = text_pairs[num_train_samples + num_val_samples :]

print(f"{len(text_pairs)} total pairs")
print(f"{len(train_pairs)} training pairs")
print(f"{len(val_pairs)} validation pairs")
print(f"{len(test_pairs)} test pairs")
```

```
118964 total pairs
83276 training pairs
17844 validation pairs
17844 test pairs
```

---

## Tokenizing the data

We'll define two tokenizers - one for the source language (English), and the other for the target language (Spanish). We'll be using [`keras_hub.tokenizers.WordPieceTokenizer`](/keras_hub/api/tokenizers/word_piece_tokenizer#wordpiecetokenizer-class) to tokenize the text. `keras_hub.tokenizers.WordPieceTokenizer` takes a WordPiece vocabulary and has functions for tokenizing the text, and detokenizing sequences of tokens.

Before we define the two tokenizers, we first need to train them on the dataset we have. The WordPiece tokenization algorithm is a subword tokenization algorithm; training it on a corpus gives us a vocabulary of subwords. A subword tokenizer is a compromise between word tokenizers (word tokenizers need very large vocabularies for good coverage of input words) and character tokenizers (characters don't really encode meaning like words do).
Luckily, KerasHub makes it very simple to train WordPiece on a corpus with the [`keras_hub.tokenizers.compute_word_piece_vocabulary`](/keras_hub/api/tokenizers/compute_word_piece_vocabulary#computewordpiecevocabulary-function) utility.

```python
def train_word_piece(text_samples, vocab_size, reserved_tokens):
    word_piece_ds = tf_data.Dataset.from_tensor_slices(text_samples)
    vocab = keras_hub.tokenizers.compute_word_piece_vocabulary(
        word_piece_ds.batch(1000).prefetch(2),
        vocabulary_size=vocab_size,
        reserved_tokens=reserved_tokens,
    )
    return vocab
```

Every vocabulary has a few special, reserved tokens. We have four such tokens:
- `"[PAD]"` - Padding token. Padding tokens are appended to the input sequence when the input sequence is shorter than the maximum sequence length.
- `"[UNK]"` - Unknown token.
- `"[START]"` - Token that marks the start of the input sequence.
- `"[END]"` - Token that marks the end of the input sequence.

```python
reserved_tokens = ["[PAD]", "[UNK]", "[START]", "[END]"]

eng_samples = [text_pair[0] for text_pair in train_pairs]
eng_vocab = train_word_piece(eng_samples, ENG_VOCAB_SIZE, reserved_tokens)

spa_samples = [text_pair[1] for text_pair in train_pairs]
spa_vocab = train_word_piece(spa_samples, SPA_VOCAB_SIZE, reserved_tokens)
```

Let's see some tokens!

```python
print("English Tokens: ", eng_vocab[100:110])
print("Spanish Tokens: ", spa_vocab[100:110])
```

```
English Tokens:  ['at', 'know', 'him', 'there', 'go', 'they', 'her', 'has', 'time', 'will']
Spanish Tokens:  ['le', 'para', 'te', 'mary', 'las', 'más', 'al', 'yo', 'tu', 'estoy']
```
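The reserved tokens are placed at the front of the vocabulary, which is worth a quick sanity check, since the decoding code later in this example pads encoder inputs with the literal value 0. A minimal sketch:

```python
# The reserved tokens should occupy the first four vocabulary slots,
# so "[PAD]" maps to token ID 0.
print(eng_vocab[:4])  # Expected: ['[PAD]', '[UNK]', '[START]', '[END]']
```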
Now, let's define the tokenizers. We will configure the tokenizers with the vocabularies trained above.

```python
eng_tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    vocabulary=eng_vocab, lowercase=False
)
spa_tokenizer = keras_hub.tokenizers.WordPieceTokenizer(
    vocabulary=spa_vocab, lowercase=False
)
```

Let's try and tokenize a sample from our dataset! To verify whether the text has been tokenized correctly, we can also detokenize the list of tokens back to the original text.

```python
eng_input_ex = text_pairs[0][0]
eng_tokens_ex = eng_tokenizer.tokenize(eng_input_ex)
print("English sentence: ", eng_input_ex)
print("Tokens: ", eng_tokens_ex)
print(
    "Recovered text after detokenizing: ",
    eng_tokenizer.detokenize(eng_tokens_ex),
)

print()

spa_input_ex = text_pairs[0][1]
spa_tokens_ex = spa_tokenizer.tokenize(spa_input_ex)
print("Spanish sentence: ", spa_input_ex)
print("Tokens: ", spa_tokens_ex)
print(
    "Recovered text after detokenizing: ",
    spa_tokenizer.detokenize(spa_tokens_ex),
)
```
class="n">detokenize</span><span class="p">(</span><span class="n">spa_tokens_ex</span><span class="p">),</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>English sentence: i am leaving the books here. Tokens: tf.Tensor([ 35 163 931 66 356 119 12], shape=(7,), dtype=int32) Recovered text after detokenizing: tf.Tensor(b'i am leaving the books here .', shape=(), dtype=string) </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Spanish sentence: dejo los libros aquí. Tokens: tf.Tensor([2962 93 350 122 14], shape=(5,), dtype=int32) Recovered text after detokenizing: tf.Tensor(b'dejo los libros aqu\xc3\xad .', shape=(), dtype=string) </code></pre></div> </div> <hr /> <h2 id="format-datasets">Format datasets</h2> <p>Next, we'll format our datasets.</p> <p>At each training step, the model will seek to predict target words N+1 (and beyond) using the source sentence and the target words 0 to N.</p> <p>As such, the training dataset will yield a tuple <code>(inputs, targets)</code>, where:</p> <ul> <li><code>inputs</code> is a dictionary with the keys <code>encoder_inputs</code> and <code>decoder_inputs</code>. <code>encoder_inputs</code> is the tokenized source sentence and <code>decoder_inputs</code> is the target sentence "so far", that is to say, the words 0 to N used to predict word N+1 (and beyond) in the target sentence.</li> <li><code>target</code> is the target sentence offset by one step: it provides the next words in the target sentence – what the model will try to predict.</li> </ul> <p>We will add special tokens, <code>"[START]"</code> and <code>"[END]"</code>, to the input Spanish sentence after tokenizing the text. We will also pad the input to a fixed length. 
```python
def preprocess_batch(eng, spa):
    batch_size = ops.shape(spa)[0]

    eng = eng_tokenizer(eng)
    spa = spa_tokenizer(spa)

    # Pad `eng` to `MAX_SEQUENCE_LENGTH`.
    eng_start_end_packer = keras_hub.layers.StartEndPacker(
        sequence_length=MAX_SEQUENCE_LENGTH,
        pad_value=eng_tokenizer.token_to_id("[PAD]"),
    )
    eng = eng_start_end_packer(eng)

    # Add special tokens (`"[START]"` and `"[END]"`) to `spa` and pad it as well.
    spa_start_end_packer = keras_hub.layers.StartEndPacker(
        sequence_length=MAX_SEQUENCE_LENGTH + 1,
        start_value=spa_tokenizer.token_to_id("[START]"),
        end_value=spa_tokenizer.token_to_id("[END]"),
        pad_value=spa_tokenizer.token_to_id("[PAD]"),
    )
    spa = spa_start_end_packer(spa)

    return (
        {
            "encoder_inputs": eng,
            "decoder_inputs": spa[:, :-1],
        },
        spa[:, 1:],
    )
class="n">spa</span><span class="p">[:,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="p">},</span> <span class="n">spa</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">:],</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">make_dataset</span><span class="p">(</span><span class="n">pairs</span><span class="p">):</span> <span class="n">eng_texts</span><span class="p">,</span> <span class="n">spa_texts</span> <span class="o">=</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">pairs</span><span class="p">)</span> <span class="n">eng_texts</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">eng_texts</span><span class="p">)</span> <span class="n">spa_texts</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">spa_texts</span><span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">tf_data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">eng_texts</span><span class="p">,</span> <span class="n">spa_texts</span><span class="p">))</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">preprocess_batch</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">tf_data</span><span class="o">.</span><span class="n">AUTOTUNE</span><span class="p">)</span> <span class="k">return</span> <span class="n">dataset</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="mi">2048</span><span class="p">)</span><span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="mi">16</span><span class="p">)</span><span class="o">.</span><span class="n">cache</span><span class="p">()</span> <span class="n">train_ds</span> <span class="o">=</span> <span class="n">make_dataset</span><span class="p">(</span><span class="n">train_pairs</span><span class="p">)</span> <span class="n">val_ds</span> <span class="o">=</span> <span class="n">make_dataset</span><span class="p">(</span><span class="n">val_pairs</span><span class="p">)</span> </code></pre></div> <p>Let's take a quick look at the sequence shapes (we have batches of 64 pairs, and all sequences are 40 steps long):</p> <div class="codehilite"><pre><span></span><code><span class="k">for</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span> <span class="ow">in</span> <span class="n">train_ds</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="mi">1</span><span class="p">):</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'inputs["encoder_inputs"].shape: </span><span class="si">{</span><span class="n">inputs</span><span class="p">[</span><span class="s2">"encoder_inputs"</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span 
class="si">}</span><span class="s1">'</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'inputs["decoder_inputs"].shape: </span><span class="si">{</span><span class="n">inputs</span><span class="p">[</span><span class="s2">"decoder_inputs"</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"targets.shape: </span><span class="si">{</span><span class="n">targets</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>inputs["encoder_inputs"].shape: (64, 40) inputs["decoder_inputs"].shape: (64, 40) targets.shape: (64, 40) </code></pre></div> </div> <hr /> <h2 id="building-the-model">Building the model</h2> <p>Now, let's move on to the exciting part - defining our model! We first need an embedding layer, i.e., a vector for every token in our input sequence. This embedding layer can be initialised randomly. We also need a positional embedding layer which encodes the word order in the sequence. The convention is to add these two embeddings. KerasHub has a <code>keras_hub.layers.TokenAndPositionEmbedding</code> layer which does all of the above steps for us.</p> <p>Our sequence-to-sequence Transformer consists of a <a href="/keras_hub/api/modeling_layers/transformer_encoder#transformerencoder-class"><code>keras_hub.layers.TransformerEncoder</code></a> layer and a <a href="/keras_hub/api/modeling_layers/transformer_decoder#transformerdecoder-class"><code>keras_hub.layers.TransformerDecoder</code></a> layer chained together.</p> <p>The source sequence will be passed to <a href="/keras_hub/api/modeling_layers/transformer_encoder#transformerencoder-class"><code>keras_hub.layers.TransformerEncoder</code></a>, which will produce a new representation of it. This new representation will then be passed to the <a href="/keras_hub/api/modeling_layers/transformer_decoder#transformerdecoder-class"><code>keras_hub.layers.TransformerDecoder</code></a>, together with the target sequence so far (target words 0 to N). The <a href="/keras_hub/api/modeling_layers/transformer_decoder#transformerdecoder-class"><code>keras_hub.layers.TransformerDecoder</code></a> will then seek to predict the next words in the target sequence (N+1 and beyond).</p> <p>A key detail that makes this possible is causal masking. The <a href="/keras_hub/api/modeling_layers/transformer_decoder#transformerdecoder-class"><code>keras_hub.layers.TransformerDecoder</code></a> sees the entire sequence at once, and thus we must make sure that it only uses information from target tokens 0 to N when predicting token N+1 (otherwise, it could use information from the future, which would result in a model that cannot be used at inference time). Causal masking is enabled by default in <a href="/keras_hub/api/modeling_layers/transformer_decoder#transformerdecoder-class"><code>keras_hub.layers.TransformerDecoder</code></a>.</p> <p>We also need to mask the padding tokens (<code>"[PAD]"</code>). 
We also need to mask the padding tokens (`"[PAD]"`). For this, we can set the `mask_zero` argument of the `keras_hub.layers.TokenAndPositionEmbedding` layer to True. This will then be propagated to all subsequent layers.

```python
# Encoder
encoder_inputs = keras.Input(shape=(None,), name="encoder_inputs")

x = keras_hub.layers.TokenAndPositionEmbedding(
    vocabulary_size=ENG_VOCAB_SIZE,
    sequence_length=MAX_SEQUENCE_LENGTH,
    embedding_dim=EMBED_DIM,
)(encoder_inputs)

encoder_outputs = keras_hub.layers.TransformerEncoder(
    intermediate_dim=INTERMEDIATE_DIM, num_heads=NUM_HEADS
)(inputs=x)
encoder = keras.Model(encoder_inputs, encoder_outputs)


# Decoder
decoder_inputs = keras.Input(shape=(None,), name="decoder_inputs")
encoded_seq_inputs = keras.Input(shape=(None, EMBED_DIM), name="decoder_state_inputs")

x = keras_hub.layers.TokenAndPositionEmbedding(
class="n">vocabulary_size</span><span class="o">=</span><span class="n">SPA_VOCAB_SIZE</span><span class="p">,</span> <span class="n">sequence_length</span><span class="o">=</span><span class="n">MAX_SEQUENCE_LENGTH</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="o">=</span><span class="n">EMBED_DIM</span><span class="p">,</span> <span class="p">)(</span><span class="n">decoder_inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">TransformerDecoder</span><span class="p">(</span> <span class="n">intermediate_dim</span><span class="o">=</span><span class="n">INTERMEDIATE_DIM</span><span class="p">,</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">NUM_HEADS</span> <span class="p">)(</span><span class="n">decoder_sequence</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">encoder_sequence</span><span class="o">=</span><span class="n">encoded_seq_inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">decoder_outputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">SPA_VOCAB_SIZE</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"softmax"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">decoder</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span> <span class="p">[</span> <span class="n">decoder_inputs</span><span class="p">,</span> <span class="n">encoded_seq_inputs</span><span class="p">,</span> <span class="p">],</span> <span class="n">decoder_outputs</span><span class="p">,</span> <span class="p">)</span> <span class="n">decoder_outputs</span> <span class="o">=</span> <span class="n">decoder</span><span class="p">([</span><span class="n">decoder_inputs</span><span class="p">,</span> <span class="n">encoder_outputs</span><span class="p">])</span> <span class="n">transformer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span> <span class="p">[</span><span class="n">encoder_inputs</span><span class="p">,</span> <span class="n">decoder_inputs</span><span class="p">],</span> <span class="n">decoder_outputs</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"transformer"</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="training-our-model">Training our model</h2> <p>We'll use accuracy as a quick way to monitor training progress on the validation data. Note that machine translation typically uses BLEU scores as well as other metrics, rather than accuracy. However, in order to use metrics like ROUGE, BLEU, etc. we will have decode the probabilities and generate the text. 
Here we only train for 1 epoch, but to get the model to actually converge you should train for at least 10 epochs.

```python
transformer.summary()
transformer.compile(
    "rmsprop", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)
transformer.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)
```

```
Model: "transformer"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃    Param # ┃ Connected to      ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ encoder_inputs      │ (None, None)      │          0 │ -                 │
│ (InputLayer)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ token_and_position… │ (None, None, 256) │  3,850,240 │ encoder_inputs[0… │
│ (TokenAndPositionE… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ decoder_inputs      │ (None, None)      │          0 │ -                 │
│ (InputLayer)        │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ transformer_encoder │ (None, None, 256) │  1,315,072 │ token_and_positi… │
│ (TransformerEncode… │                   │            │                   │
├─────────────────────┼───────────────────┼────────────┼───────────────────┤
│ functional_3        │ (None, None,      │  9,283,992 │ decoder_inputs[0… │
│ (Functional)        │ 15000)            │            │ transformer_enco… │
└─────────────────────┴───────────────────┴────────────┴───────────────────┘

 Total params: 14,449,304 (55.12 MB)
 Trainable params: 14,449,304 (55.12 MB)
 Non-trainable params: 0 (0.00 B)

 1302/1302 ━━━━━━━━━━━━━━━━━━━━ 1701s 1s/step - accuracy: 0.8168 - loss: 1.4819 - val_accuracy: 0.8650 - val_loss: 0.8129

<keras.src.callbacks.history.History at 0x7efdd7ee6a50>
```

---

## Decoding test sentences (qualitative analysis)

Finally, let's demonstrate how to translate brand new English sentences. We simply feed into the model the tokenized English sentence as well as the target token `"[START]"`. The model outputs probabilities of the next token. We then repeatedly generate the next token, conditioned on the tokens generated so far, until we hit the token `"[END]"`.

For decoding, we will use the `keras_hub.samplers` module from KerasHub. Greedy decoding is a text decoding method which outputs the most likely next token at each time step, i.e., the token with the highest probability.
<div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">decode_sequences</span><span class="p">(</span><span class="n">input_sentences</span><span class="p">):</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># Tokenize the encoder input.</span> <span class="n">encoder_input_tokens</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">convert_to_tensor</span><span class="p">(</span><span class="n">eng_tokenizer</span><span class="p">(</span><span class="n">input_sentences</span><span class="p">))</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">encoder_input_tokens</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="o"><</span> <span class="n">MAX_SEQUENCE_LENGTH</span><span class="p">:</span> <span class="n">pads</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="n">MAX_SEQUENCE_LENGTH</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">encoder_input_tokens</span><span class="p">[</span><span class="mi">0</span><span class="p">])),</span> <span class="mi">0</span><span class="p">)</span> <span class="n">encoder_input_tokens</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span> <span class="p">[</span><span class="n">encoder_input_tokens</span><span class="o">.</span><span class="n">to_tensor</span><span class="p">(),</span> <span class="n">pads</span><span class="p">],</span> <span class="mi">1</span> <span class="p">)</span> <span class="c1"># Define a function that outputs the next token's probability given the</span> <span class="c1"># input sequence.</span> <span class="k">def</span> <span class="nf">next</span><span class="p">(</span><span class="n">prompt</span><span class="p">,</span> <span class="n">cache</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span> <span class="n">logits</span> <span class="o">=</span> <span class="n">transformer</span><span class="p">([</span><span class="n">encoder_input_tokens</span><span class="p">,</span> <span class="n">prompt</span><span class="p">])[:,</span> <span class="n">index</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="p">:]</span> <span class="c1"># Ignore hidden states for now; only needed for contrastive search.</span> <span class="n">hidden_states</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">return</span> <span class="n">logits</span><span class="p">,</span> <span class="n">hidden_states</span><span class="p">,</span> <span class="n">cache</span> <span class="c1"># Build a prompt of length 40 with a start token and padding tokens.</span> <span class="n">length</span> <span class="o">=</span> <span class="mi">40</span> <span class="n">start</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span 
class="p">),</span> <span class="n">spa_tokenizer</span><span class="o">.</span><span class="n">token_to_id</span><span class="p">(</span><span class="s2">"[START]"</span><span class="p">))</span> <span class="n">pad</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">full</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">length</span> <span class="o">-</span> <span class="mi">1</span><span class="p">),</span> <span class="n">spa_tokenizer</span><span class="o">.</span><span class="n">token_to_id</span><span class="p">(</span><span class="s2">"[PAD]"</span><span class="p">))</span> <span class="n">prompt</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">start</span><span class="p">,</span> <span class="n">pad</span><span class="p">),</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="n">generated_tokens</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">samplers</span><span class="o">.</span><span class="n">GreedySampler</span><span class="p">()(</span> <span class="nb">next</span><span class="p">,</span> <span class="n">prompt</span><span class="p">,</span> <span class="n">stop_token_ids</span><span class="o">=</span><span class="p">[</span><span class="n">spa_tokenizer</span><span class="o">.</span><span class="n">token_to_id</span><span class="p">(</span><span class="s2">"[END]"</span><span class="p">)],</span> <span class="n">index</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="c1"># Start sampling after start token.</span> <span class="p">)</span> <span class="n">generated_sentences</span> <span class="o">=</span> <span class="n">spa_tokenizer</span><span class="o">.</span><span class="n">detokenize</span><span class="p">(</span><span class="n">generated_tokens</span><span class="p">)</span> <span class="k">return</span> <span class="n">generated_sentences</span> <span class="n">test_eng_texts</span> <span class="o">=</span> <span class="p">[</span><span class="n">pair</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">pair</span> <span class="ow">in</span> <span class="n">test_pairs</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">):</span> <span class="n">input_sentence</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">test_eng_texts</span><span class="p">)</span> <span class="n">translated</span> <span class="o">=</span> <span class="n">decode_sequences</span><span class="p">([</span><span class="n">input_sentence</span><span class="p">])</span> <span class="n">translated</span> <span class="o">=</span> <span class="n">translated</span><span class="o">.</span><span class="n">numpy</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="s2">"utf-8"</span><span class="p">)</span> <span class="n">translated</span> <span class="o">=</span> <span class="p">(</span> <span 
class="n">translated</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">"[PAD]"</span><span class="p">,</span> <span class="s2">""</span><span class="p">)</span> <span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">"[START]"</span><span class="p">,</span> <span class="s2">""</span><span class="p">)</span> <span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">"[END]"</span><span class="p">,</span> <span class="s2">""</span><span class="p">)</span> <span class="o">.</span><span class="n">strip</span><span class="p">()</span> <span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"** Example </span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2"> **"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">input_sentence</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">translated</span><span class="p">)</span> <span class="nb">print</span><span class="p">()</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1714519073.816969 34774 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. ** Example 0 ** i got the ticket free of charge. me pregunto la comprome . </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>** Example 1 ** i think maybe that's all you have to do. creo que tom le dije que hacer eso . </code></pre></div> </div> <hr /> <h2 id="evaluating-our-model-quantitative-analysis">Evaluating our model (quantitative analysis)</h2> <p>There are many metrics which are used for text generation tasks. Here, to evaluate translations generated by our model, let's compute the ROUGE-1 and ROUGE-2 scores. Essentially, ROUGE-N is a score based on the number of common n-grams between the reference text and the generated text. 
<p>We will calculate the score over 30 test samples (since decoding is an expensive process).</p> <div class="codehilite"><pre><span></span><code><span class="n">rouge_1</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">RougeN</span><span class="p">(</span><span class="n">order</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="n">rouge_2</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">RougeN</span><span class="p">(</span><span class="n">order</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> <span class="k">for</span> <span class="n">test_pair</span> <span class="ow">in</span> <span class="n">test_pairs</span><span class="p">[:</span><span class="mi">30</span><span class="p">]:</span> <span class="n">input_sentence</span> <span class="o">=</span> <span class="n">test_pair</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="n">reference_sentence</span> <span class="o">=</span> <span class="n">test_pair</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="n">translated_sentence</span> <span class="o">=</span> <span class="n">decode_sequences</span><span class="p">([</span><span class="n">input_sentence</span><span class="p">])</span> <span class="n">translated_sentence</span> <span class="o">=</span> <span class="n">translated_sentence</span><span class="o">.</span><span class="n">numpy</span><span class="p">()[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="s2">"utf-8"</span><span class="p">)</span> <span class="n">translated_sentence</span> <span class="o">=</span> <span class="p">(</span> <span class="n">translated_sentence</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">"[PAD]"</span><span class="p">,</span> <span class="s2">""</span><span class="p">)</span> <span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">"[START]"</span><span class="p">,</span> <span class="s2">""</span><span class="p">)</span> <span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">"[END]"</span><span class="p">,</span> <span class="s2">""</span><span class="p">)</span> <span class="o">.</span><span class="n">strip</span><span class="p">()</span> <span class="p">)</span> <span class="n">rouge_1</span><span class="p">(</span><span class="n">reference_sentence</span><span class="p">,</span> <span class="n">translated_sentence</span><span class="p">)</span> <span class="n">rouge_2</span><span class="p">(</span><span class="n">reference_sentence</span><span class="p">,</span> <span class="n">translated_sentence</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"ROUGE-1 Score: "</span><span class="p">,</span> <span class="n">rouge_1</span><span class="o">.</span><span class="n">result</span><span class="p">())</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"ROUGE-2 Score: "</span><span class="p">,</span> <span class="n">rouge_2</span><span class="o">.</span><span 
class="n">result</span><span class="p">())</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>ROUGE-1 Score: {'precision': <tf.Tensor: shape=(), dtype=float32, numpy=0.30989552>, 'recall': <tf.Tensor: shape=(), dtype=float32, numpy=0.37136248>, 'f1_score': <tf.Tensor: shape=(), dtype=float32, numpy=0.33032653>} ROUGE-2 Score: {'precision': <tf.Tensor: shape=(), dtype=float32, numpy=0.08999339>, 'recall': <tf.Tensor: shape=(), dtype=float32, numpy=0.09524643>, 'f1_score': <tf.Tensor: shape=(), dtype=float32, numpy=0.08855649>} </code></pre></div> </div> <p>After 10 epochs, the scores are as follows:</p> <table> <thead> <tr> <th style="text-align: center;"></th> <th style="text-align: center;"><strong>ROUGE-1</strong></th> <th style="text-align: center;"><strong>ROUGE-2</strong></th> </tr> </thead> <tbody> <tr> <td style="text-align: center;"><strong>Precision</strong></td> <td style="text-align: center;">0.568</td> <td style="text-align: center;">0.374</td> </tr> <tr> <td style="text-align: center;"><strong>Recall</strong></td> <td style="text-align: center;">0.615</td> <td style="text-align: center;">0.394</td> </tr> <tr> <td style="text-align: center;"><strong>F1 Score</strong></td> <td style="text-align: center;">0.579</td> <td style="text-align: center;">0.381</td> </tr> </tbody> </table> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#englishtospanish-translation-with-kerashub'>English-to-Spanish translation with KerasHub</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#downloading-the-data'>Downloading the data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#parsing-the-data'>Parsing the data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#tokenizing-the-data'>Tokenizing the data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#format-datasets'>Format datasets</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#building-the-model'>Building the model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#training-our-model'>Training our model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#decoding-test-sentences-qualitative-analysis'>Decoding test sentences (qualitative analysis)</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#evaluating-our-model-quantitative-analysis'>Evaluating our model (quantitative analysis)</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>