CINXE.COM
English-to-Spanish translation with a sequence-to-sequence Transformer
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/nlp/neural_machine_translation_with_transformer/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: English-to-Spanish translation with a sequence-to-sequence Transformer"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: English-to-Spanish translation with a sequence-to-sequence Transformer"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>English-to-Spanish translation with a sequence-to-sequence Transformer</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink active" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_from_scratch/">Text classification from scratch</a> <a class="nav-sublink2" href="/examples/nlp/active_learning_review_classification/">Review Classification using Active Learning</a> <a class="nav-sublink2" href="/examples/nlp/fnet_classification_with_keras_hub/">Text Classification using FNet</a> <a class="nav-sublink2" href="/examples/nlp/multi_label_classification/">Large-scale multi-label text classification</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_with_transformer/">Text classification with Transformer</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_with_switch_transformer/">Text classification with Switch Transformer</a> <a class="nav-sublink2" href="/examples/nlp/tweet-classification-using-tfdf/">Text classification using Decision Forests and pretrained embeddings</a> <a class="nav-sublink2" href="/examples/nlp/pretrained_word_embeddings/">Using pre-trained word embeddings</a> <a class="nav-sublink2" href="/examples/nlp/bidirectional_lstm_imdb/">Bidirectional LSTM on IMDB</a> <a class="nav-sublink2" href="/examples/nlp/data_parallel_training_with_keras_hub/">Data Parallel Training with KerasHub and tf.distribute</a> <a class="nav-sublink2" href="/examples/nlp/neural_machine_translation_with_keras_hub/">English-to-Spanish translation with KerasHub</a> <a class="nav-sublink2 active" href="/examples/nlp/neural_machine_translation_with_transformer/">English-to-Spanish translation with a sequence-to-sequence Transformer</a> <a class="nav-sublink2" href="/examples/nlp/lstm_seq2seq/">Character-level recurrent sequence-to-sequence model</a> <a class="nav-sublink2" href="/examples/nlp/multimodal_entailment/">Multimodal entailment</a> <a class="nav-sublink2" href="/examples/nlp/ner_transformers/">Named Entity Recognition using Transformers</a> <a class="nav-sublink2" href="/examples/nlp/text_extraction_with_bert/">Text Extraction with BERT</a> <a class="nav-sublink2" href="/examples/nlp/addition_rnn/">Sequence to sequence learning for performing number addition</a> <a class="nav-sublink2" href="/examples/nlp/semantic_similarity_with_keras_hub/">Semantic Similarity with KerasHub</a> <a class="nav-sublink2" href="/examples/nlp/semantic_similarity_with_bert/">Semantic Similarity with BERT</a> <a class="nav-sublink2" href="/examples/nlp/sentence_embeddings_with_sbert/">Sentence embeddings using Siamese RoBERTa-networks</a> <a class="nav-sublink2" href="/examples/nlp/masked_language_modeling/">End-to-end Masked Language Modeling with BERT</a> <a class="nav-sublink2" href="/examples/nlp/abstractive_summarization_with_bart/">Abstractive Text Summarization with BART</a> <a class="nav-sublink2" href="/examples/nlp/pretraining_BERT/">Pretraining BERT with Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/">Parameter-efficient fine-tuning of GPT-2 with LoRA</a> <a class="nav-sublink2" href="/examples/nlp/mlm_training_tpus/">Training a language model from scratch with 🤗 Transformers and TPUs</a> <a class="nav-sublink2" href="/examples/nlp/multiple_choice_task_with_transfer_learning/">MultipleChoice Task with Transfer Learning</a> <a class="nav-sublink2" href="/examples/nlp/question_answering/">Question Answering with Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/nlp/t5_hf_summarization/">Abstractive Summarization with Hugging Face Transformers</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/nlp/'>Natural Language Processing</a> / English-to-Spanish translation with a sequence-to-sequence Transformer </div> <div class='k-content'> <h1 id="englishtospanish-translation-with-a-sequencetosequence-transformer">English-to-Spanish translation with a sequence-to-sequence Transformer</h1> <p><strong>Author:</strong> <a href="https://twitter.com/fchollet">fchollet</a><br> <strong>Date created:</strong> 2021/05/26<br> <strong>Last modified:</strong> 2023/02/25<br> <strong>Description:</strong> Implementing a sequence-to-sequence Transformer and training it on a machine translation task.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/nlp/ipynb/neural_machine_translation_with_transformer.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/nlp/neural_machine_translation_with_transformer.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>In this example, we'll build a sequence-to-sequence Transformer model, which we'll train on an English-to-Spanish machine translation task.</p> <p>You'll learn how to:</p> <ul> <li>Vectorize text using the Keras <code>TextVectorization</code> layer.</li> <li>Implement a <code>TransformerEncoder</code> layer, a <code>TransformerDecoder</code> layer, and a <code>PositionalEmbedding</code> layer.</li> <li>Prepare data for training a sequence-to-sequence model.</li> <li>Use the trained model to generate translations of never-seen-before input sentences (sequence-to-sequence inference).</li> </ul> <p>The code featured here is adapted from the book <a href="https://www.manning.com/books/deep-learning-with-python-second-edition">Deep Learning with Python, Second Edition</a> (chapter 11: Deep learning for text). The present example is fairly barebones, so for detailed explanations of how each building block works, as well as the theory behind Transformers, I recommend reading the book.</p> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="c1"># We set the backend to TensorFlow. The code works with</span> <span class="c1"># both `tensorflow` and `torch`. It does not work with JAX</span> <span class="c1"># due to the behavior of `jax.numpy.tile` in a jit scope</span> <span class="c1"># (used in `TransformerDecoder.get_causal_attention_mask()`:</span> <span class="c1"># `tile` in JAX does not support a dynamic `reps` argument.</span> <span class="c1"># You can make the code work in JAX by wrapping the</span> <span class="c1"># inside of the `get_causal_attention_mask` method in</span> <span class="c1"># a decorator to prevent jit compilation:</span> <span class="c1"># `with jax.ensure_compile_time_eval():`.</span> <span class="kn">import</span> <span class="nn">os</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"tensorflow"</span> <span class="kn">import</span> <span class="nn">pathlib</span> <span class="kn">import</span> <span class="nn">random</span> <span class="kn">import</span> <span class="nn">string</span> <span class="kn">import</span> <span class="nn">re</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">tensorflow.data</span> <span class="k">as</span> <span class="nn">tf_data</span> <span class="kn">import</span> <span class="nn">tensorflow.strings</span> <span class="k">as</span> <span class="nn">tf_strings</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">ops</span> <span class="kn">from</span> <span class="nn">keras.layers</span> <span class="kn">import</span> <span class="n">TextVectorization</span> </code></pre></div> <hr /> <h2 id="downloading-the-data">Downloading the data</h2> <p>We'll be working with an English-to-Spanish translation dataset provided by <a href="https://www.manythings.org/anki/">Anki</a>. Let's download it:</p> <div class="codehilite"><pre><span></span><code><span class="n">text_file</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">get_file</span><span class="p">(</span> <span class="n">fname</span><span class="o">=</span><span class="s2">"spa-eng.zip"</span><span class="p">,</span> <span class="n">origin</span><span class="o">=</span><span class="s2">"http://storage.googleapis.com/download.tensorflow.org/data/spa-eng.zip"</span><span class="p">,</span> <span class="n">extract</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="n">text_file</span> <span class="o">=</span> <span class="n">pathlib</span><span class="o">.</span><span class="n">Path</span><span class="p">(</span><span class="n">text_file</span><span class="p">)</span><span class="o">.</span><span class="n">parent</span> <span class="o">/</span> <span class="s2">"spa-eng"</span> <span class="o">/</span> <span class="s2">"spa.txt"</span> </code></pre></div> <hr /> <h2 id="parsing-the-data">Parsing the data</h2> <p>Each line contains an English sentence and its corresponding Spanish sentence. The English sentence is the <em>source sequence</em> and Spanish one is the <em>target sequence</em>. We prepend the token <code>"[start]"</code> and we append the token <code>"[end]"</code> to the Spanish sentence.</p> <div class="codehilite"><pre><span></span><code><span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">text_file</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span> <span class="n">lines</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">()</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"</span><span class="se">\n</span><span class="s2">"</span><span class="p">)[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">text_pairs</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">line</span> <span class="ow">in</span> <span class="n">lines</span><span class="p">:</span> <span class="n">eng</span><span class="p">,</span> <span class="n">spa</span> <span class="o">=</span> <span class="n">line</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"</span><span class="se">\t</span><span class="s2">"</span><span class="p">)</span> <span class="n">spa</span> <span class="o">=</span> <span class="s2">"[start] "</span> <span class="o">+</span> <span class="n">spa</span> <span class="o">+</span> <span class="s2">" [end]"</span> <span class="n">text_pairs</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">eng</span><span class="p">,</span> <span class="n">spa</span><span class="p">))</span> </code></pre></div> <p>Here's what our sentence pairs look like:</p> <div class="codehilite"><pre><span></span><code><span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">5</span><span class="p">):</span> <span class="nb">print</span><span class="p">(</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">text_pairs</span><span class="p">))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>("On Saturday nights, it's difficult to find parking around here.", '[start] Los sábados por la noche es difícil encontrar aparcamiento por aquí. [end]') ('I was the worst student in the class.', '[start] Fui el peor estudiante en la clase. [end]') ('There is nothing to do today.', '[start] No hay nada que hacer hoy. [end]') ('The twins do resemble each other.', '[start] Los gemelos se parecen mutuamente. [end]') ('They found Tom in the crowd.', '[start] Encontraron a Tom entre la multitud. [end]') </code></pre></div> </div> <p>Now, let's split the sentence pairs into a training set, a validation set, and a test set.</p> <div class="codehilite"><pre><span></span><code><span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">text_pairs</span><span class="p">)</span> <span class="n">num_val_samples</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="mf">0.15</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">text_pairs</span><span class="p">))</span> <span class="n">num_train_samples</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">text_pairs</span><span class="p">)</span> <span class="o">-</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">num_val_samples</span> <span class="n">train_pairs</span> <span class="o">=</span> <span class="n">text_pairs</span><span class="p">[:</span><span class="n">num_train_samples</span><span class="p">]</span> <span class="n">val_pairs</span> <span class="o">=</span> <span class="n">text_pairs</span><span class="p">[</span><span class="n">num_train_samples</span> <span class="p">:</span> <span class="n">num_train_samples</span> <span class="o">+</span> <span class="n">num_val_samples</span><span class="p">]</span> <span class="n">test_pairs</span> <span class="o">=</span> <span class="n">text_pairs</span><span class="p">[</span><span class="n">num_train_samples</span> <span class="o">+</span> <span class="n">num_val_samples</span> <span class="p">:]</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">text_pairs</span><span class="p">)</span><span class="si">}</span><span class="s2"> total pairs"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">train_pairs</span><span class="p">)</span><span class="si">}</span><span class="s2"> training pairs"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">val_pairs</span><span class="p">)</span><span class="si">}</span><span class="s2"> validation pairs"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">test_pairs</span><span class="p">)</span><span class="si">}</span><span class="s2"> test pairs"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>118964 total pairs 83276 training pairs 17844 validation pairs 17844 test pairs </code></pre></div> </div> <hr /> <h2 id="vectorizing-the-text-data">Vectorizing the text data</h2> <p>We'll use two instances of the <code>TextVectorization</code> layer to vectorize the text data (one for English and one for Spanish), that is to say, to turn the original strings into integer sequences where each integer represents the index of a word in a vocabulary.</p> <p>The English layer will use the default string standardization (strip punctuation characters) and splitting scheme (split on whitespace), while the Spanish layer will use a custom standardization, where we add the character <code>"¿"</code> to the set of punctuation characters to be stripped.</p> <p>Note: in a production-grade machine translation model, I would not recommend stripping the punctuation characters in either language. Instead, I would recommend turning each punctuation character into its own token, which you could achieve by providing a custom <code>split</code> function to the <code>TextVectorization</code> layer.</p> <div class="codehilite"><pre><span></span><code><span class="n">strip_chars</span> <span class="o">=</span> <span class="n">string</span><span class="o">.</span><span class="n">punctuation</span> <span class="o">+</span> <span class="s2">"¿"</span> <span class="n">strip_chars</span> <span class="o">=</span> <span class="n">strip_chars</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">"["</span><span class="p">,</span> <span class="s2">""</span><span class="p">)</span> <span class="n">strip_chars</span> <span class="o">=</span> <span class="n">strip_chars</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">"]"</span><span class="p">,</span> <span class="s2">""</span><span class="p">)</span> <span class="n">vocab_size</span> <span class="o">=</span> <span class="mi">15000</span> <span class="n">sequence_length</span> <span class="o">=</span> <span class="mi">20</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">64</span> <span class="k">def</span> <span class="nf">custom_standardization</span><span class="p">(</span><span class="n">input_string</span><span class="p">):</span> <span class="n">lowercase</span> <span class="o">=</span> <span class="n">tf_strings</span><span class="o">.</span><span class="n">lower</span><span class="p">(</span><span class="n">input_string</span><span class="p">)</span> <span class="k">return</span> <span class="n">tf_strings</span><span class="o">.</span><span class="n">regex_replace</span><span class="p">(</span><span class="n">lowercase</span><span class="p">,</span> <span class="s2">"[</span><span class="si">%s</span><span class="s2">]"</span> <span class="o">%</span> <span class="n">re</span><span class="o">.</span><span class="n">escape</span><span class="p">(</span><span class="n">strip_chars</span><span class="p">),</span> <span class="s2">""</span><span class="p">)</span> <span class="n">eng_vectorization</span> <span class="o">=</span> <span class="n">TextVectorization</span><span class="p">(</span> <span class="n">max_tokens</span><span class="o">=</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">output_mode</span><span class="o">=</span><span class="s2">"int"</span><span class="p">,</span> <span class="n">output_sequence_length</span><span class="o">=</span><span class="n">sequence_length</span><span class="p">,</span> <span class="p">)</span> <span class="n">spa_vectorization</span> <span class="o">=</span> <span class="n">TextVectorization</span><span class="p">(</span> <span class="n">max_tokens</span><span class="o">=</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">output_mode</span><span class="o">=</span><span class="s2">"int"</span><span class="p">,</span> <span class="n">output_sequence_length</span><span class="o">=</span><span class="n">sequence_length</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">standardize</span><span class="o">=</span><span class="n">custom_standardization</span><span class="p">,</span> <span class="p">)</span> <span class="n">train_eng_texts</span> <span class="o">=</span> <span class="p">[</span><span class="n">pair</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">pair</span> <span class="ow">in</span> <span class="n">train_pairs</span><span class="p">]</span> <span class="n">train_spa_texts</span> <span class="o">=</span> <span class="p">[</span><span class="n">pair</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">pair</span> <span class="ow">in</span> <span class="n">train_pairs</span><span class="p">]</span> <span class="n">eng_vectorization</span><span class="o">.</span><span class="n">adapt</span><span class="p">(</span><span class="n">train_eng_texts</span><span class="p">)</span> <span class="n">spa_vectorization</span><span class="o">.</span><span class="n">adapt</span><span class="p">(</span><span class="n">train_spa_texts</span><span class="p">)</span> </code></pre></div> <p>Next, we'll format our datasets.</p> <p>At each training step, the model will seek to predict target words N+1 (and beyond) using the source sentence and the target words 0 to N.</p> <p>As such, the training dataset will yield a tuple <code>(inputs, targets)</code>, where:</p> <ul> <li><code>inputs</code> is a dictionary with the keys <code>encoder_inputs</code> and <code>decoder_inputs</code>. <code>encoder_inputs</code> is the vectorized source sentence and <code>encoder_inputs</code> is the target sentence "so far", that is to say, the words 0 to N used to predict word N+1 (and beyond) in the target sentence.</li> <li><code>target</code> is the target sentence offset by one step: it provides the next words in the target sentence – what the model will try to predict.</li> </ul> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">format_dataset</span><span class="p">(</span><span class="n">eng</span><span class="p">,</span> <span class="n">spa</span><span class="p">):</span> <span class="n">eng</span> <span class="o">=</span> <span class="n">eng_vectorization</span><span class="p">(</span><span class="n">eng</span><span class="p">)</span> <span class="n">spa</span> <span class="o">=</span> <span class="n">spa_vectorization</span><span class="p">(</span><span class="n">spa</span><span class="p">)</span> <span class="k">return</span> <span class="p">(</span> <span class="p">{</span> <span class="s2">"encoder_inputs"</span><span class="p">:</span> <span class="n">eng</span><span class="p">,</span> <span class="s2">"decoder_inputs"</span><span class="p">:</span> <span class="n">spa</span><span class="p">[:,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="p">},</span> <span class="n">spa</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">:],</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">make_dataset</span><span class="p">(</span><span class="n">pairs</span><span class="p">):</span> <span class="n">eng_texts</span><span class="p">,</span> <span class="n">spa_texts</span> <span class="o">=</span> <span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">pairs</span><span class="p">)</span> <span class="n">eng_texts</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">eng_texts</span><span class="p">)</span> <span class="n">spa_texts</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">spa_texts</span><span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">tf_data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">eng_texts</span><span class="p">,</span> <span class="n">spa_texts</span><span class="p">))</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">format_dataset</span><span class="p">)</span> <span class="k">return</span> <span class="n">dataset</span><span class="o">.</span><span class="n">cache</span><span class="p">()</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="mi">2048</span><span class="p">)</span><span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="mi">16</span><span class="p">)</span> <span class="n">train_ds</span> <span class="o">=</span> <span class="n">make_dataset</span><span class="p">(</span><span class="n">train_pairs</span><span class="p">)</span> <span class="n">val_ds</span> <span class="o">=</span> <span class="n">make_dataset</span><span class="p">(</span><span class="n">val_pairs</span><span class="p">)</span> </code></pre></div> <p>Let's take a quick look at the sequence shapes (we have batches of 64 pairs, and all sequences are 20 steps long):</p> <div class="codehilite"><pre><span></span><code><span class="k">for</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span> <span class="ow">in</span> <span class="n">train_ds</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="mi">1</span><span class="p">):</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'inputs["encoder_inputs"].shape: </span><span class="si">{</span><span class="n">inputs</span><span class="p">[</span><span class="s2">"encoder_inputs"</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'inputs["decoder_inputs"].shape: </span><span class="si">{</span><span class="n">inputs</span><span class="p">[</span><span class="s2">"decoder_inputs"</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"targets.shape: </span><span class="si">{</span><span class="n">targets</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>inputs["encoder_inputs"].shape: (64, 20) inputs["decoder_inputs"].shape: (64, 20) targets.shape: (64, 20) </code></pre></div> </div> <hr /> <h2 id="building-the-model">Building the model</h2> <p>Our sequence-to-sequence Transformer consists of a <code>TransformerEncoder</code> and a <code>TransformerDecoder</code> chained together. To make the model aware of word order, we also use a <code>PositionalEmbedding</code> layer.</p> <p>The source sequence will be pass to the <code>TransformerEncoder</code>, which will produce a new representation of it. This new representation will then be passed to the <code>TransformerDecoder</code>, together with the target sequence so far (target words 0 to N). The <code>TransformerDecoder</code> will then seek to predict the next words in the target sequence (N+1 and beyond).</p> <p>A key detail that makes this possible is causal masking (see method <code>get_causal_attention_mask()</code> on the <code>TransformerDecoder</code>). The <code>TransformerDecoder</code> sees the entire sequences at once, and thus we must make sure that it only uses information from target tokens 0 to N when predicting token N+1 (otherwise, it could use information from the future, which would result in a model that cannot be used at inference time).</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">keras.ops</span> <span class="k">as</span> <span class="nn">ops</span> <span class="k">class</span> <span class="nc">TransformerEncoder</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">dense_dim</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span> <span class="o">=</span> <span class="n">embed_dim</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_dim</span> <span class="o">=</span> <span class="n">dense_dim</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span> <span class="o">=</span> <span class="n">num_heads</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MultiHeadAttention</span><span class="p">(</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">key_dim</span><span class="o">=</span><span class="n">embed_dim</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_proj</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">dense_dim</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">),</span> <span class="p">]</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">layernorm_1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">layernorm_2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">supports_masking</span> <span class="o">=</span> <span class="kc">True</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="n">padding_mask</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">mask</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:],</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int32"</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">padding_mask</span> <span class="o">=</span> <span class="kc">None</span> <span class="n">attention_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention</span><span class="p">(</span> <span class="n">query</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">attention_mask</span><span class="o">=</span><span class="n">padding_mask</span> <span class="p">)</span> <span class="n">proj_input</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layernorm_1</span><span class="p">(</span><span class="n">inputs</span> <span class="o">+</span> <span class="n">attention_output</span><span class="p">)</span> <span class="n">proj_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_proj</span><span class="p">(</span><span class="n">proj_input</span><span class="p">)</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layernorm_2</span><span class="p">(</span><span class="n">proj_input</span> <span class="o">+</span> <span class="n">proj_output</span><span class="p">)</span> <span class="k">def</span> <span class="nf">get_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="n">config</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">get_config</span><span class="p">()</span> <span class="n">config</span><span class="o">.</span><span class="n">update</span><span class="p">(</span> <span class="p">{</span> <span class="s2">"embed_dim"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span><span class="p">,</span> <span class="s2">"dense_dim"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_dim</span><span class="p">,</span> <span class="s2">"num_heads"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="p">}</span> <span class="p">)</span> <span class="k">return</span> <span class="n">config</span> <span class="k">class</span> <span class="nc">PositionalEmbedding</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sequence_length</span><span class="p">,</span> <span class="n">vocab_size</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">token_embeddings</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span> <span class="n">input_dim</span><span class="o">=</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">output_dim</span><span class="o">=</span><span class="n">embed_dim</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">position_embeddings</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span> <span class="n">input_dim</span><span class="o">=</span><span class="n">sequence_length</span><span class="p">,</span> <span class="n">output_dim</span><span class="o">=</span><span class="n">embed_dim</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length</span> <span class="o">=</span> <span class="n">sequence_length</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span> <span class="o">=</span> <span class="n">vocab_size</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span> <span class="o">=</span> <span class="n">embed_dim</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">length</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">positions</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">length</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="n">embedded_tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">token_embeddings</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">embedded_positions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">position_embeddings</span><span class="p">(</span><span class="n">positions</span><span class="p">)</span> <span class="k">return</span> <span class="n">embedded_tokens</span> <span class="o">+</span> <span class="n">embedded_positions</span> <span class="k">def</span> <span class="nf">compute_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> <span class="k">return</span> <span class="kc">None</span> <span class="k">else</span><span class="p">:</span> <span class="k">return</span> <span class="n">ops</span><span class="o">.</span><span class="n">not_equal</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="k">def</span> <span class="nf">get_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="n">config</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">get_config</span><span class="p">()</span> <span class="n">config</span><span class="o">.</span><span class="n">update</span><span class="p">(</span> <span class="p">{</span> <span class="s2">"sequence_length"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length</span><span class="p">,</span> <span class="s2">"vocab_size"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">vocab_size</span><span class="p">,</span> <span class="s2">"embed_dim"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span><span class="p">,</span> <span class="p">}</span> <span class="p">)</span> <span class="k">return</span> <span class="n">config</span> <span class="k">class</span> <span class="nc">TransformerDecoder</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">latent_dim</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span> <span class="o">=</span> <span class="n">embed_dim</span> <span class="bp">self</span><span class="o">.</span><span class="n">latent_dim</span> <span class="o">=</span> <span class="n">latent_dim</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span> <span class="o">=</span> <span class="n">num_heads</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention_1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MultiHeadAttention</span><span class="p">(</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">key_dim</span><span class="o">=</span><span class="n">embed_dim</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention_2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MultiHeadAttention</span><span class="p">(</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">key_dim</span><span class="o">=</span><span class="n">embed_dim</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_proj</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">latent_dim</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">),</span> <span class="p">]</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">layernorm_1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">layernorm_2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">layernorm_3</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">supports_masking</span> <span class="o">=</span> <span class="kc">True</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">encoder_outputs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">causal_mask</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_causal_attention_mask</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="k">if</span> <span class="n">mask</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="n">padding_mask</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">mask</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:],</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int32"</span><span class="p">)</span> <span class="n">padding_mask</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">minimum</span><span class="p">(</span><span class="n">padding_mask</span><span class="p">,</span> <span class="n">causal_mask</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">padding_mask</span> <span class="o">=</span> <span class="kc">None</span> <span class="n">attention_output_1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention_1</span><span class="p">(</span> <span class="n">query</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">attention_mask</span><span class="o">=</span><span class="n">causal_mask</span> <span class="p">)</span> <span class="n">out_1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layernorm_1</span><span class="p">(</span><span class="n">inputs</span> <span class="o">+</span> <span class="n">attention_output_1</span><span class="p">)</span> <span class="n">attention_output_2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention_2</span><span class="p">(</span> <span class="n">query</span><span class="o">=</span><span class="n">out_1</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">encoder_outputs</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">encoder_outputs</span><span class="p">,</span> <span class="n">attention_mask</span><span class="o">=</span><span class="n">padding_mask</span><span class="p">,</span> <span class="p">)</span> <span class="n">out_2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layernorm_2</span><span class="p">(</span><span class="n">out_1</span> <span class="o">+</span> <span class="n">attention_output_2</span><span class="p">)</span> <span class="n">proj_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_proj</span><span class="p">(</span><span class="n">out_2</span><span class="p">)</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layernorm_3</span><span class="p">(</span><span class="n">out_2</span> <span class="o">+</span> <span class="n">proj_output</span><span class="p">)</span> <span class="k">def</span> <span class="nf">get_causal_attention_mask</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">input_shape</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">sequence_length</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">input_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="n">i</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">sequence_length</span><span class="p">)[:,</span> <span class="kc">None</span><span class="p">]</span> <span class="n">j</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">sequence_length</span><span class="p">)</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">i</span> <span class="o">>=</span> <span class="n">j</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int32"</span><span class="p">)</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">input_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span> <span class="n">mult</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span> <span class="p">[</span><span class="n">ops</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">),</span> <span class="n">ops</span><span class="o">.</span><span class="n">convert_to_tensor</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="p">)</span> <span class="k">return</span> <span class="n">ops</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">mult</span><span class="p">)</span> <span class="k">def</span> <span class="nf">get_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="n">config</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">get_config</span><span class="p">()</span> <span class="n">config</span><span class="o">.</span><span class="n">update</span><span class="p">(</span> <span class="p">{</span> <span class="s2">"embed_dim"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span><span class="p">,</span> <span class="s2">"latent_dim"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">latent_dim</span><span class="p">,</span> <span class="s2">"num_heads"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="p">}</span> <span class="p">)</span> <span class="k">return</span> <span class="n">config</span> </code></pre></div> <p>Next, we assemble the end-to-end model.</p> <div class="codehilite"><pre><span></span><code><span class="n">embed_dim</span> <span class="o">=</span> <span class="mi">256</span> <span class="n">latent_dim</span> <span class="o">=</span> <span class="mi">2048</span> <span class="n">num_heads</span> <span class="o">=</span> <span class="mi">8</span> <span class="n">encoder_inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int64"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"encoder_inputs"</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">PositionalEmbedding</span><span class="p">(</span><span class="n">sequence_length</span><span class="p">,</span> <span class="n">vocab_size</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">)(</span><span class="n">encoder_inputs</span><span class="p">)</span> <span class="n">encoder_outputs</span> <span class="o">=</span> <span class="n">TransformerEncoder</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">latent_dim</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">encoder</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">encoder_inputs</span><span class="p">,</span> <span class="n">encoder_outputs</span><span class="p">)</span> <span class="n">decoder_inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int64"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"decoder_inputs"</span><span class="p">)</span> <span class="n">encoded_seq_inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"decoder_state_inputs"</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">PositionalEmbedding</span><span class="p">(</span><span class="n">sequence_length</span><span class="p">,</span> <span class="n">vocab_size</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">)(</span><span class="n">decoder_inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">TransformerDecoder</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">latent_dim</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span class="n">encoded_seq_inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">decoder_outputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"softmax"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">decoder</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">([</span><span class="n">decoder_inputs</span><span class="p">,</span> <span class="n">encoded_seq_inputs</span><span class="p">],</span> <span class="n">decoder_outputs</span><span class="p">)</span> <span class="n">decoder_outputs</span> <span class="o">=</span> <span class="n">decoder</span><span class="p">([</span><span class="n">decoder_inputs</span><span class="p">,</span> <span class="n">encoder_outputs</span><span class="p">])</span> <span class="n">transformer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span> <span class="p">[</span><span class="n">encoder_inputs</span><span class="p">,</span> <span class="n">decoder_inputs</span><span class="p">],</span> <span class="n">decoder_outputs</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"transformer"</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="training-our-model">Training our model</h2> <p>We'll use accuracy as a quick way to monitor training progress on the validation data. Note that machine translation typically uses BLEU scores as well as other metrics, rather than accuracy.</p> <p>Here we only train for 1 epoch, but to get the model to actually converge you should train for at least 30 epochs.</p> <div class="codehilite"><pre><span></span><code><span class="n">epochs</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># This should be at least 30 for convergence</span> <span class="n">transformer</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> <span class="n">transformer</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="s2">"rmsprop"</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"sparse_categorical_crossentropy"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"accuracy"</span><span class="p">]</span> <span class="p">)</span> <span class="n">transformer</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_ds</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">epochs</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">val_ds</span><span class="p">)</span> </code></pre></div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "transformer"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃<span style="font-weight: bold"> Connected to </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩ │ encoder_inputs │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ - │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">InputLayer</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ positional_embeddi… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">256</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">3,845,…</span> │ encoder_inputs[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>] │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">PositionalEmbeddi…</span> │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ decoder_inputs │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ - │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">InputLayer</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ transformer_encoder │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">256</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">3,155,…</span> │ positional_embeddin… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">TransformerEncode…</span> │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ functional_5 │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, │ <span style="color: #00af00; text-decoration-color: #00af00">12,959…</span> │ decoder_inputs[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Functional</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">15000</span>) │ │ transformer_encoder… │ └─────────────────────┴───────────────────┴─────────┴──────────────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">19,960,216</span> (76.14 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">19,960,216</span> (76.14 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 5/1302 [37m━━━━━━━━━━━━━━━━━━━━ 42s 33ms/step - accuracy: 0.3558 - loss: 8.3596 WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1699484373.932513 76082 device_compiler.h:187] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. 1302/1302 ━━━━━━━━━━━━━━━━━━━━ 64s 39ms/step - accuracy: 0.7073 - loss: 2.2372 - val_accuracy: 0.7329 - val_loss: 1.6477 <keras.src.callbacks.history.History at 0x7ff611f21540> </code></pre></div> </div> <hr /> <h2 id="decoding-test-sentences">Decoding test sentences</h2> <p>Finally, let's demonstrate how to translate brand new English sentences. We simply feed into the model the vectorized English sentence as well as the target token <code>"[start]"</code>, then we repeatedly generated the next token, until we hit the token <code>"[end]"</code>.</p> <div class="codehilite"><pre><span></span><code><span class="n">spa_vocab</span> <span class="o">=</span> <span class="n">spa_vectorization</span><span class="o">.</span><span class="n">get_vocabulary</span><span class="p">()</span> <span class="n">spa_index_lookup</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">spa_vocab</span><span class="p">)),</span> <span class="n">spa_vocab</span><span class="p">))</span> <span class="n">max_decoded_sentence_length</span> <span class="o">=</span> <span class="mi">20</span> <span class="k">def</span> <span class="nf">decode_sequence</span><span class="p">(</span><span class="n">input_sentence</span><span class="p">):</span> <span class="n">tokenized_input_sentence</span> <span class="o">=</span> <span class="n">eng_vectorization</span><span class="p">([</span><span class="n">input_sentence</span><span class="p">])</span> <span class="n">decoded_sentence</span> <span class="o">=</span> <span class="s2">"[start]"</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">max_decoded_sentence_length</span><span class="p">):</span> <span class="n">tokenized_target_sentence</span> <span class="o">=</span> <span class="n">spa_vectorization</span><span class="p">([</span><span class="n">decoded_sentence</span><span class="p">])[:,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">transformer</span><span class="p">([</span><span class="n">tokenized_input_sentence</span><span class="p">,</span> <span class="n">tokenized_target_sentence</span><span class="p">])</span> <span class="c1"># ops.argmax(predictions[0, i, :]) is not a concrete value for jax here</span> <span class="n">sampled_token_index</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">convert_to_numpy</span><span class="p">(</span> <span class="n">ops</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">predictions</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">i</span><span class="p">,</span> <span class="p">:])</span> <span class="p">)</span><span class="o">.</span><span class="n">item</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="n">sampled_token</span> <span class="o">=</span> <span class="n">spa_index_lookup</span><span class="p">[</span><span class="n">sampled_token_index</span><span class="p">]</span> <span class="n">decoded_sentence</span> <span class="o">+=</span> <span class="s2">" "</span> <span class="o">+</span> <span class="n">sampled_token</span> <span class="k">if</span> <span class="n">sampled_token</span> <span class="o">==</span> <span class="s2">"[end]"</span><span class="p">:</span> <span class="k">break</span> <span class="k">return</span> <span class="n">decoded_sentence</span> <span class="n">test_eng_texts</span> <span class="o">=</span> <span class="p">[</span><span class="n">pair</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">pair</span> <span class="ow">in</span> <span class="n">test_pairs</span><span class="p">]</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">30</span><span class="p">):</span> <span class="n">input_sentence</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">test_eng_texts</span><span class="p">)</span> <span class="n">translated</span> <span class="o">=</span> <span class="n">decode_sequence</span><span class="p">(</span><span class="n">input_sentence</span><span class="p">)</span> </code></pre></div> <p>After 30 epochs, we get results such as:</p> <blockquote> <p>She handed him the money. [start] ella le pasó el dinero [end]</p> <p>Tom has never heard Mary sing. [start] tom nunca ha oído cantar a mary [end]</p> <p>Perhaps she will come tomorrow. [start] tal vez ella vendrá mañana [end]</p> <p>I love to write. [start] me encanta escribir [end]</p> <p>His French is improving little by little. [start] su francés va a [UNK] sólo un poco [end]</p> <p>My hotel told me to call you. [start] mi hotel me dijo que te [UNK] [end]</p> </blockquote> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#englishtospanish-translation-with-a-sequencetosequence-transformer'>English-to-Spanish translation with a sequence-to-sequence Transformer</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#downloading-the-data'>Downloading the data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#parsing-the-data'>Parsing the data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#vectorizing-the-text-data'>Vectorizing the text data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#building-the-model'>Building the model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#training-our-model'>Training our model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#decoding-test-sentences'>Decoding test sentences</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>