<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/nlp/text_classification_with_switch_transformer/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Text classification with Switch Transformer"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Text classification with Switch Transformer"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Text classification with Switch Transformer</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 
'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink active" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_from_scratch/">Text classification from scratch</a> <a class="nav-sublink2" href="/examples/nlp/active_learning_review_classification/">Review Classification using Active Learning</a> <a class="nav-sublink2" href="/examples/nlp/fnet_classification_with_keras_hub/">Text Classification using FNet</a> <a class="nav-sublink2" href="/examples/nlp/multi_label_classification/">Large-scale multi-label text classification</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_with_transformer/">Text classification with Transformer</a> <a class="nav-sublink2 active" 
href="/examples/nlp/text_classification_with_switch_transformer/">Text classification with Switch Transformer</a> <a class="nav-sublink2" href="/examples/nlp/tweet-classification-using-tfdf/">Text classification using Decision Forests and pretrained embeddings</a> <a class="nav-sublink2" href="/examples/nlp/pretrained_word_embeddings/">Using pre-trained word embeddings</a> <a class="nav-sublink2" href="/examples/nlp/bidirectional_lstm_imdb/">Bidirectional LSTM on IMDB</a> <a class="nav-sublink2" href="/examples/nlp/data_parallel_training_with_keras_hub/">Data Parallel Training with KerasHub and tf.distribute</a> <a class="nav-sublink2" href="/examples/nlp/neural_machine_translation_with_keras_hub/">English-to-Spanish translation with KerasHub</a> <a class="nav-sublink2" href="/examples/nlp/neural_machine_translation_with_transformer/">English-to-Spanish translation with a sequence-to-sequence Transformer</a> <a class="nav-sublink2" href="/examples/nlp/lstm_seq2seq/">Character-level recurrent sequence-to-sequence model</a> <a class="nav-sublink2" href="/examples/nlp/multimodal_entailment/">Multimodal entailment</a> <a class="nav-sublink2" href="/examples/nlp/ner_transformers/">Named Entity Recognition using Transformers</a> <a class="nav-sublink2" href="/examples/nlp/text_extraction_with_bert/">Text Extraction with BERT</a> <a class="nav-sublink2" href="/examples/nlp/addition_rnn/">Sequence to sequence learning for performing number addition</a> <a class="nav-sublink2" href="/examples/nlp/semantic_similarity_with_keras_hub/">Semantic Similarity with KerasHub</a> <a class="nav-sublink2" href="/examples/nlp/semantic_similarity_with_bert/">Semantic Similarity with BERT</a> <a class="nav-sublink2" href="/examples/nlp/sentence_embeddings_with_sbert/">Sentence embeddings using Siamese RoBERTa-networks</a> <a class="nav-sublink2" href="/examples/nlp/masked_language_modeling/">End-to-end Masked Language Modeling with BERT</a> <a class="nav-sublink2" 
href="/examples/nlp/abstractive_summarization_with_bart/">Abstractive Text Summarization with BART</a> <a class="nav-sublink2" href="/examples/nlp/pretraining_BERT/">Pretraining BERT with Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/">Parameter-efficient fine-tuning of GPT-2 with LoRA</a> <a class="nav-sublink2" href="/examples/nlp/mlm_training_tpus/">Training a language model from scratch with 🤗 Transformers and TPUs</a> <a class="nav-sublink2" href="/examples/nlp/multiple_choice_task_with_transfer_learning/">MultipleChoice Task with Transfer Learning</a> <a class="nav-sublink2" href="/examples/nlp/question_answering/">Question Answering with Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/nlp/t5_hf_summarization/">Abstractive Summarization with Hugging Face Transformers</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = 
"none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." 
autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/nlp/'>Natural Language Processing</a> / Text classification with Switch Transformer </div> <div class='k-content'> <h1 id="text-classification-with-switch-transformer">Text classification with Switch Transformer</h1> <p><strong>Author:</strong> <a href="https://www.linkedin.com/in/khalid-salama-24403144/">Khalid Salama</a><br> <strong>Date created:</strong> 2020/05/10<br> <strong>Last modified:</strong> 2021/02/15<br> <strong>Description:</strong> Implement a Switch Transformer for text classification.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img 
class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/nlp/ipynb/text_classification_with_switch_transformer.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/nlp/text_classification_with_switch_transformer.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>This example demonstrates the implementation of the <a href="https://arxiv.org/abs/2101.03961">Switch Transformer</a> model for text classification.</p> <p>The Switch Transformer replaces the feedforward network (FFN) layer in the standard Transformer with a Mixture of Experts (MoE) routing layer, where each expert operates independently on the tokens in the sequence. This allows increasing the model size without increasing the computation needed to process each example.</p> <p>Note that, to train the Switch Transformer efficiently, data and model parallelism need to be applied, so that expert modules can run simultaneously, each on its own accelerator. 
While the implementation described in the paper uses the <a href="https://github.com/tensorflow/mesh">TensorFlow Mesh</a> framework for distributed training, this example presents a simple, non-distributed implementation of the Switch Transformer model for demonstration purposes.</p> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">keras</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">ops</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> </code></pre></div> <hr /> <h2 id="download-and-prepare-dataset">Download and prepare dataset</h2> <div class="codehilite"><pre><span></span><code><span class="n">vocab_size</span> <span class="o">=</span> <span class="mi">20000</span> <span class="c1"># Only consider the top 20k words</span> <span class="n">num_tokens_per_example</span> <span class="o">=</span> <span class="mi">200</span> <span class="c1"># Only consider the first 200 words of each movie review</span> <span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_val</span><span class="p">,</span> <span class="n">y_val</span><span class="p">)</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">imdb</span><span class="o">.</span><span class="n">load_data</span><span class="p">(</span><span class="n">num_words</span><span class="o">=</span><span class="n">vocab_size</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">x_train</span><span class="p">),</span> <span class="s2">&quot;Training sequences&quot;</span><span class="p">)</span> 
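# Aside (illustration only, not part of the original example):
# keras.utils.pad_sequences pads or truncates every sequence to the same
# length (maxlen), adding zeros at the start ("pre" padding) by default, e.g.
# keras.utils.pad_sequences([[1, 2, 3]], maxlen=5) -> [[0, 0, 1, 2, 3]]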
<span class="nb">print</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">x_val</span><span class="p">),</span> <span class="s2">&quot;Validation sequences&quot;</span><span class="p">)</span> <span class="n">x_train</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">pad_sequences</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">maxlen</span><span class="o">=</span><span class="n">num_tokens_per_example</span><span class="p">)</span> <span class="n">x_val</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">pad_sequences</span><span class="p">(</span><span class="n">x_val</span><span class="p">,</span> <span class="n">maxlen</span><span class="o">=</span><span class="n">num_tokens_per_example</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>25000 Training sequences 25000 Validation sequences </code></pre></div> </div> <hr /> <h2 id="define-hyperparameters">Define hyperparameters</h2> <div class="codehilite"><pre><span></span><code><span class="n">embed_dim</span> <span class="o">=</span> <span class="mi">32</span> <span class="c1"># Embedding size for each token.</span> <span class="n">num_heads</span> <span class="o">=</span> <span class="mi">2</span> <span class="c1"># Number of attention heads</span> <span class="n">ff_dim</span> <span class="o">=</span> <span class="mi">32</span> <span class="c1"># Hidden layer size in feedforward network.</span> <span class="n">num_experts</span> <span class="o">=</span> <span class="mi">10</span> <span class="c1"># Number of experts used in the Switch Transformer.</span> <span class="n">batch_size</span> <span class="o">=</span> <span 
class="mi">50</span> <span class="c1"># Batch size.</span> <span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">0.001</span> <span class="c1"># Learning rate.</span> <span class="n">dropout_rate</span> <span class="o">=</span> <span class="mf">0.25</span> <span class="c1"># Dropout rate.</span> <span class="n">num_epochs</span> <span class="o">=</span> <span class="mi">3</span> <span class="c1"># Number of epochs.</span> <span class="n">num_tokens_per_batch</span> <span class="o">=</span> <span class="p">(</span> <span class="n">batch_size</span> <span class="o">*</span> <span class="n">num_tokens_per_example</span> <span class="p">)</span> <span class="c1"># Total number of tokens per batch.</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Number of tokens per batch: </span><span class="si">{</span><span class="n">num_tokens_per_batch</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Number of tokens per batch: 10000 </code></pre></div> </div> <hr /> <h2 id="implement-token-amp-position-embedding-layer">Implement token &amp; position embedding layer</h2> <p>It consists of two separate embedding layers: one for the tokens and one for the token positions (indices); the two embeddings are added together.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">TokenAndPositionEmbedding</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">maxlen</span><span class="p">,</span> <span class="n">vocab_size</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">):</span> <span class="nb">super</span><span
class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">token_emb</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">input_dim</span><span class="o">=</span><span class="n">vocab_size</span><span class="p">,</span> <span class="n">output_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">pos_emb</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">input_dim</span><span class="o">=</span><span class="n">maxlen</span><span class="p">,</span> <span class="n">output_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="n">maxlen</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">positions</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">start</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">stop</span><span class="o">=</span><span class="n">maxlen</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="n">positions</span> <span class="o">=</span> <span 
class="bp">self</span><span class="o">.</span><span class="n">pos_emb</span><span class="p">(</span><span class="n">positions</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">token_emb</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> <span class="o">+</span> <span class="n">positions</span> </code></pre></div> <hr /> <h2 id="implement-the-feedforward-network">Implement the feedforward network</h2> <p>This is used as the Mixture of Experts in the Switch Transformer.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">create_feedforward_network</span><span class="p">(</span><span class="n">ff_dim</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">ff_dim</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">)],</span> <span class="n">name</span><span class="o">=</span><span class="n">name</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="implement-the-loadbalanced-loss">Implement the load-balanced loss</h2> <p>This is an auxiliary loss to encourage a balanced load across experts.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> 
<span class="nf">load_balanced_loss</span><span class="p">(</span><span class="n">router_probs</span><span class="p">,</span> <span class="n">expert_mask</span><span class="p">):</span> <span class="c1"># router_probs [tokens_per_batch, num_experts] is the probability assigned for</span> <span class="c1"># each expert per token. expert_mask [tokens_per_batch, num_experts] contains</span> <span class="c1"># the expert with the highest router probability in one−hot format.</span> <span class="n">num_experts</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">expert_mask</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="c1"># Get the fraction of tokens routed to each expert.</span> <span class="c1"># density is a vector of length num experts that sums to 1.</span> <span class="n">density</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">expert_mask</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="c1"># Get fraction of probability mass assigned to each expert from the router</span> <span class="c1"># across all tokens. density_proxy is a vector of length num experts that sums to 1.</span> <span class="n">density_proxy</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">router_probs</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="c1"># Want both vectors to have uniform allocation (1/num experts) across all</span> <span class="c1"># num_expert elements. 
The two vectors will be pushed towards uniform allocation</span> <span class="c1"># when the dot product is minimized.</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">density_proxy</span> <span class="o">*</span> <span class="n">density</span><span class="p">)</span> <span class="o">*</span> <span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">((</span><span class="n">num_experts</span><span class="o">**</span><span class="mi">2</span><span class="p">),</span> <span class="s2">&quot;float32&quot;</span><span class="p">)</span> <span class="k">return</span> <span class="n">loss</span> </code></pre></div> <h3 id="implement-the-router-as-a-layer">Implement the router as a layer</h3> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">Router</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_experts</span><span class="p">,</span> <span class="n">expert_capacity</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span> <span class="o">=</span> <span class="n">num_experts</span> <span class="bp">self</span><span class="o">.</span><span class="n">route</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="n">num_experts</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">expert_capacity</span> <span class="o">=</span> <span class="n">expert_capacity</span> <span 
class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span> <span class="c1"># inputs shape: [tokens_per_batch, embed_dim]</span> <span class="c1"># router_logits shape: [tokens_per_batch, num_experts]</span> <span class="n">router_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">route</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="k">if</span> <span class="n">training</span><span class="p">:</span> <span class="c1"># Add noise for exploration across experts.</span> <span class="n">router_logits</span> <span class="o">+=</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="n">router_logits</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">minval</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">maxval</span><span class="o">=</span><span class="mf">1.1</span> <span class="p">)</span> <span class="c1"># Probabilities for each token of what expert it should be sent to.</span> <span class="n">router_probs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">activations</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">router_logits</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span 
class="mi">1</span><span class="p">)</span> <span class="c1"># Get the top−1 expert for each token. expert_gate is the top−1 probability</span> <span class="c1"># from the router for each token. expert_index is what expert each token</span> <span class="c1"># is going to be routed to.</span> <span class="n">expert_gate</span><span class="p">,</span> <span class="n">expert_index</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">top_k</span><span class="p">(</span><span class="n">router_probs</span><span class="p">,</span> <span class="n">k</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># expert_mask shape: [tokens_per_batch, num_experts]</span> <span class="n">expert_mask</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">expert_index</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span><span class="p">)</span> <span class="c1"># Compute load balancing loss.</span> <span class="n">aux_loss</span> <span class="o">=</span> <span class="n">load_balanced_loss</span><span class="p">(</span><span class="n">router_probs</span><span class="p">,</span> <span class="n">expert_mask</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_loss</span><span class="p">(</span><span class="n">aux_loss</span><span class="p">)</span> <span class="c1"># Experts have a fixed capacity, ensure we do not exceed it. 
</span> <span class="c1"># Compute the position of each token within its assigned expert, making</span> <span class="c1"># sure that no more than expert_capacity tokens are routed to each expert.</span> <span class="n">position_in_expert</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span> <span class="n">ops</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">expert_mask</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="o">*</span> <span class="n">expert_mask</span><span class="p">,</span> <span class="s2">&quot;int32&quot;</span> <span class="p">)</span> <span class="c1"># Keep only tokens that fit within expert capacity.</span> <span class="n">expert_mask</span> <span class="o">*=</span> <span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span> <span class="n">ops</span><span class="o">.</span><span class="n">less</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">position_in_expert</span><span class="p">,</span> <span class="s2">&quot;int32&quot;</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">expert_capacity</span><span class="p">),</span> <span class="s2">&quot;float32&quot;</span><span class="p">,</span> <span class="p">)</span> <span class="n">expert_mask_flat</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">expert_mask</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># Mask out the experts that have overflowed the expert capacity.</span> <span
class="n">expert_gate</span> <span class="o">*=</span> <span class="n">expert_mask_flat</span> <span class="c1"># Combine expert outputs and scaling with router probability.</span> <span class="c1"># combine_tensor shape: [tokens_per_batch, num_experts, expert_capacity]</span> <span class="n">combined_tensor</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span> <span class="n">expert_gate</span> <span class="o">*</span> <span class="n">expert_mask_flat</span> <span class="o">*</span> <span class="n">ops</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">expert_index</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span><span class="p">),</span> <span class="mi">1</span><span class="p">),</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">)</span> <span class="o">*</span> <span class="n">ops</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">position_in_expert</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">expert_capacity</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span> <span class="c1"># Create binary dispatch_tensor [tokens_per_batch, num_experts, expert_capacity]</span> <span class="c1"># that is 1 if the token gets routed to the corresponding expert.</span> <span class="n">dispatch_tensor</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">combined_tensor</span><span class="p">,</span> <span 
class="s2">&quot;float32&quot;</span><span class="p">)</span> <span class="k">return</span> <span class="n">dispatch_tensor</span><span class="p">,</span> <span class="n">combined_tensor</span> </code></pre></div> <h3 id="implement-a-switch-layer">Implement a Switch layer</h3> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">Switch</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">num_experts</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">ff_dim</span><span class="p">,</span> <span class="n">num_tokens_per_batch</span><span class="p">,</span> <span class="n">capacity_factor</span><span class="o">=</span><span class="mi">1</span> <span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span> <span class="o">=</span> <span class="n">num_experts</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span> <span class="o">=</span> <span class="n">embed_dim</span> <span class="bp">self</span><span class="o">.</span><span class="n">experts</span> <span class="o">=</span> <span class="p">[</span> <span class="n">create_feedforward_network</span><span class="p">(</span><span class="n">ff_dim</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_experts</span><span class="p">)</span> <span class="p">]</span> <span class="bp">self</span><span class="o">.</span><span class="n">expert_capacity</span> <span class="o">=</span> <span class="n">num_tokens_per_batch</span> <span 
class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span> <span class="o">*</span> <span class="n">capacity_factor</span> <span class="bp">self</span><span class="o">.</span><span class="n">router</span> <span class="o">=</span> <span class="n">Router</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">expert_capacity</span><span class="p">)</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="n">num_tokens_per_example</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span> <span class="c1"># inputs shape: [num_tokens_per_batch, embed_dim]</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="p">[</span><span class="n">num_tokens_per_batch</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span><span class="p">])</span> <span class="c1"># dispatch_tensor shape: [tokens_per_batch, num_experts, expert_capacity]</span> <span class="c1"># combine_tensor shape: 
[tokens_per_batch, num_experts, expert_capacity]</span> <span class="n">dispatch_tensor</span><span class="p">,</span> <span class="n">combine_tensor</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">router</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="c1"># expert_inputs shape: [num_experts, expert_capacity, embed_dim]</span> <span class="n">expert_inputs</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">&quot;ab,acd-&gt;cdb&quot;</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">dispatch_tensor</span><span class="p">)</span> <span class="n">expert_inputs</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span> <span class="n">expert_inputs</span><span class="p">,</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">num_experts</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">expert_capacity</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span><span class="p">]</span> <span class="p">)</span> <span class="c1"># Dispatch to experts</span> <span class="n">expert_input_list</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">unstack</span><span class="p">(</span><span class="n">expert_inputs</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">expert_output_list</span> <span class="o">=</span> <span class="p">[</span> <span class="bp">self</span><span class="o">.</span><span class="n">experts</span><span class="p">[</span><span 
class="n">idx</span><span class="p">](</span><span class="n">expert_input</span><span class="p">)</span> <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">expert_input</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">expert_input_list</span><span class="p">)</span> <span class="p">]</span> <span class="c1"># expert_outputs shape: [expert_capacity, num_experts, embed_dim]</span> <span class="n">expert_outputs</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">expert_output_list</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># expert_outputs_combined shape: [tokens_per_batch, embed_dim]</span> <span class="n">expert_outputs_combined</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span> <span class="s2">&quot;abc,xba-&gt;xc&quot;</span><span class="p">,</span> <span class="n">expert_outputs</span><span class="p">,</span> <span class="n">combine_tensor</span> <span class="p">)</span> <span class="c1"># output shape: [batch_size, num_tokens_per_example, embed_dim]</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span> <span class="n">expert_outputs_combined</span><span class="p">,</span> <span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">num_tokens_per_example</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span><span class="p">],</span> <span class="p">)</span> <span class="k">return</span> <span class="n">outputs</span> </code></pre></div> <hr /> <h2 
id="implement-a-transformer-block-layer">Implement a Transformer block layer</h2> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">TransformerBlock</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">ffn</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">att</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MultiHeadAttention</span><span class="p">(</span><span class="n">num_heads</span><span class="o">=</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">key_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">)</span> <span class="c1"># The ffn can be either a standard feedforward network or a switch</span> <span class="c1"># layer with a Mixture of Experts.</span> <span class="bp">self</span><span class="o">.</span><span class="n">ffn</span> <span class="o">=</span> <span class="n">ffn</span> <span class="bp">self</span><span class="o">.</span><span class="n">layernorm1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span><span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span 
class="n">layernorm2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span><span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_rate</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_rate</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span> <span class="n">attn_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">att</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">inputs</span><span class="p">)</span> <span class="n">attn_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout1</span><span class="p">(</span><span class="n">attn_output</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="n">training</span><span class="p">)</span> <span class="n">out1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layernorm1</span><span class="p">(</span><span class="n">inputs</span> <span class="o">+</span> <span class="n">attn_output</span><span 
class="p">)</span> <span class="n">ffn_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ffn</span><span class="p">(</span><span class="n">out1</span><span class="p">)</span> <span class="n">ffn_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout2</span><span class="p">(</span><span class="n">ffn_output</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="n">training</span><span class="p">)</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layernorm2</span><span class="p">(</span><span class="n">out1</span> <span class="o">+</span> <span class="n">ffn_output</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="implement-the-classifier">Implement the classifier</h2> <p>The <code>TransformerBlock</code> layer outputs one vector for each time step of our input sequence. 
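</p>
<p>Concretely, the block maps inputs of shape <code>(batch_size, num_tokens_per_example, embed_dim)</code> to an output of the same shape. Averaging that output over the time axis is what <code>layers.GlobalAveragePooling1D</code> computes; a minimal NumPy sketch, with toy shapes assumed purely for illustration:</p>
<div class="codehilite"><pre><code>import numpy as np

# Toy values: batch of 2 examples, 4 time steps, embed_dim of 3.
x = np.arange(24, dtype="float32").reshape(2, 4, 3)

# Average over the time axis, as layers.GlobalAveragePooling1D() does
# for inputs of shape (batch, time, features).
pooled = x.mean(axis=1)
print(pooled.shape)  # (2, 3)
</code></pre></div>
<p>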
Here, we take the mean across all time steps and use a feedforward network on top of it to classify text.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">create_classifier</span><span class="p">():</span> <span class="n">switch</span> <span class="o">=</span> <span class="n">Switch</span><span class="p">(</span><span class="n">num_experts</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">ff_dim</span><span class="p">,</span> <span class="n">num_tokens_per_batch</span><span class="p">)</span> <span class="n">transformer_block</span> <span class="o">=</span> <span class="n">TransformerBlock</span><span class="p">(</span><span class="n">embed_dim</span> <span class="o">//</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">switch</span><span class="p">)</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">num_tokens_per_example</span><span class="p">,))</span> <span class="n">embedding_layer</span> <span class="o">=</span> <span class="n">TokenAndPositionEmbedding</span><span class="p">(</span> <span class="n">num_tokens_per_example</span><span class="p">,</span> <span class="n">vocab_size</span><span class="p">,</span> <span class="n">embed_dim</span> <span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">embedding_layer</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">transformer_block</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span 
class="o">.</span><span class="n">GlobalAveragePooling1D</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_rate</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">ff_dim</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_rate</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;softmax&quot;</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">classifier</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">outputs</span><span class="p">)</span> <span class="k">return</span> <span class="n">classifier</span> </code></pre></div> <hr /> <h2 id="train-and-evaluate-the-model">Train and evaluate the model</h2> <div 
class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">run_experiment</span><span class="p">(</span><span class="n">classifier</span><span class="p">):</span> <span class="n">classifier</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="p">),</span> <span class="n">loss</span><span class="o">=</span><span class="s2">&quot;sparse_categorical_crossentropy&quot;</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;accuracy&quot;</span><span class="p">],</span> <span class="p">)</span> <span class="n">history</span> <span class="o">=</span> <span class="n">classifier</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">num_epochs</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="p">(</span><span class="n">x_val</span><span class="p">,</span> <span class="n">y_val</span><span class="p">),</span> <span class="p">)</span> <span class="k">return</span> <span class="n">history</span> <span class="n">classifier</span> <span class="o">=</span> <span class="n">create_classifier</span><span class="p">()</span> <span class="n">run_experiment</span><span class="p">(</span><span class="n">classifier</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div 
class="codehilite"><pre><span></span><code>Epoch 1/3 500/500 ━━━━━━━━━━━━━━━━━━━━ 251s 485ms/step - accuracy: 0.7121 - loss: 1.5394 - val_accuracy: 0.8748 - val_loss: 1.2891 Epoch 2/3 500/500 ━━━━━━━━━━━━━━━━━━━━ 240s 480ms/step - accuracy: 0.9243 - loss: 1.2063 - val_accuracy: 0.8752 - val_loss: 1.3090 Epoch 3/3 500/500 ━━━━━━━━━━━━━━━━━━━━ 242s 485ms/step - accuracy: 0.9572 - loss: 1.1222 - val_accuracy: 0.8614 - val_loss: 1.3744 &lt;keras.src.callbacks.history.History at 0x7efb79d82a90&gt; </code></pre></div> </div> <hr /> <h2 id="conclusion">Conclusion</h2> <p>Compared to the standard Transformer architecture, the Switch Transformer can have a much larger number of parameters, leading to increased model capacity, while maintaining a reasonable computational cost.</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#text-classification-with-switch-transformer'>Text classification with Switch Transformer</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#download-and-prepare-dataset'>Download and prepare dataset</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#define-hyperparameters'>Define hyperparameters</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#implement-token-amp-position-embedding-layer'>Implement token & position embedding layer</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#implement-the-feedforward-network'>Implement the feedforward network</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#implement-the-loadbalanced-loss'>Implement the load-balanced loss</a> </div> <div class='k-outline-depth-3'> <a href='#implement-the-router-as-a-layer'>Implement the router as a layer</a> </div> <div class='k-outline-depth-3'> <a href='#implement-a-switch-layer'>Implement a Switch layer</a> </div> <div class='k-outline-depth-2'> ◆ <a 
href='#implement-a-transformer-block-layer'>Implement a Transformer block layer</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#implement-the-classifier'>Implement the classifier</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-and-evaluate-the-model'>Train and evaluate the model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#conclusion'>Conclusion</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>
