<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/nlp/semantic_similarity_with_bert/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Semantic Similarity with BERT"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Semantic Similarity with BERT"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Semantic Similarity with BERT</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- 
End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink active" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_from_scratch/">Text classification from scratch</a> <a class="nav-sublink2" href="/examples/nlp/active_learning_review_classification/">Review Classification using Active Learning</a> <a class="nav-sublink2" href="/examples/nlp/fnet_classification_with_keras_hub/">Text Classification using FNet</a> <a class="nav-sublink2" href="/examples/nlp/multi_label_classification/">Large-scale multi-label text classification</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_with_transformer/">Text classification with Transformer</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_with_switch_transformer/">Text classification with Switch Transformer</a> 
<a class="nav-sublink2" href="/examples/nlp/tweet-classification-using-tfdf/">Text classification using Decision Forests and pretrained embeddings</a> <a class="nav-sublink2" href="/examples/nlp/pretrained_word_embeddings/">Using pre-trained word embeddings</a> <a class="nav-sublink2" href="/examples/nlp/bidirectional_lstm_imdb/">Bidirectional LSTM on IMDB</a> <a class="nav-sublink2" href="/examples/nlp/data_parallel_training_with_keras_hub/">Data Parallel Training with KerasHub and tf.distribute</a> <a class="nav-sublink2" href="/examples/nlp/neural_machine_translation_with_keras_hub/">English-to-Spanish translation with KerasHub</a> <a class="nav-sublink2" href="/examples/nlp/neural_machine_translation_with_transformer/">English-to-Spanish translation with a sequence-to-sequence Transformer</a> <a class="nav-sublink2" href="/examples/nlp/lstm_seq2seq/">Character-level recurrent sequence-to-sequence model</a> <a class="nav-sublink2" href="/examples/nlp/multimodal_entailment/">Multimodal entailment</a> <a class="nav-sublink2" href="/examples/nlp/ner_transformers/">Named Entity Recognition using Transformers</a> <a class="nav-sublink2" href="/examples/nlp/text_extraction_with_bert/">Text Extraction with BERT</a> <a class="nav-sublink2" href="/examples/nlp/addition_rnn/">Sequence to sequence learning for performing number addition</a> <a class="nav-sublink2" href="/examples/nlp/semantic_similarity_with_keras_hub/">Semantic Similarity with KerasHub</a> <a class="nav-sublink2 active" href="/examples/nlp/semantic_similarity_with_bert/">Semantic Similarity with BERT</a> <a class="nav-sublink2" href="/examples/nlp/sentence_embeddings_with_sbert/">Sentence embeddings using Siamese RoBERTa-networks</a> <a class="nav-sublink2" href="/examples/nlp/masked_language_modeling/">End-to-end Masked Language Modeling with BERT</a> <a class="nav-sublink2" href="/examples/nlp/abstractive_summarization_with_bart/">Abstractive Text Summarization with BART</a> <a class="nav-sublink2" 
href="/examples/nlp/pretraining_BERT/">Pretraining BERT with Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/">Parameter-efficient fine-tuning of GPT-2 with LoRA</a> <a class="nav-sublink2" href="/examples/nlp/mlm_training_tpus/">Training a language model from scratch with 🤗 Transformers and TPUs</a> <a class="nav-sublink2" href="/examples/nlp/multiple_choice_task_with_transfer_learning/">MultipleChoice Task with Transfer Learning</a> <a class="nav-sublink2" href="/examples/nlp/question_answering/">Question Answering with Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/nlp/t5_hf_summarization/">Abstractive Summarization with Hugging Face Transformers</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function 
resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." 
autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + encodeURIComponent(query); return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/nlp/'>Natural Language Processing</a> / Semantic Similarity with BERT </div> <div class='k-content'> <h1 id="semantic-similarity-with-bert">Semantic Similarity with BERT</h1> <p><strong>Author:</strong> <a href="https://twitter.com/mohmadmerchant1">Mohamad Merchant</a><br> <strong>Date created:</strong> 2020/08/15<br> <strong>Last modified:</strong> 2020/08/29<br> <strong>Description:</strong> Natural Language Inference by fine-tuning a BERT model on the SNLI Corpus.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon"
src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/nlp/ipynb/semantic_similarity_with_bert.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/nlp/semantic_similarity_with_bert.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>Semantic Similarity is the task of determining how similar two sentences are in terms of their meaning. This example demonstrates the use of the SNLI (Stanford Natural Language Inference) Corpus to predict sentence semantic similarity with Transformers. We will fine-tune a BERT model that takes two sentences as inputs and outputs a similarity score for them.</p> <h3 id="references">References</h3> <ul> <li><a href="https://arxiv.org/pdf/1810.04805.pdf">BERT</a></li> <li><a href="https://nlp.stanford.edu/projects/snli/">SNLI</a></li> </ul> <hr /> <h2 id="setup">Setup</h2> <p>Note: install HuggingFace <code>transformers</code> via <code>pip install transformers</code> (version >= 2.11.0).</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span> <span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span> <span class="kn">import</span> <span class="nn">transformers</span> </code></pre></div> <hr /> <h2 id="configuration">Configuration</h2> <div class="codehilite"><pre><span></span><code><span class="n">max_length</span> <span class="o">=</span> <span class="mi">128</span> <span class="c1"># Maximum length of input sentence to the model.</span> <span
class="n">batch_size</span> <span class="o">=</span> <span class="mi">32</span> <span class="n">epochs</span> <span class="o">=</span> <span class="mi">2</span> <span class="c1"># Labels in our dataset.</span> <span class="n">labels</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"contradiction"</span><span class="p">,</span> <span class="s2">"entailment"</span><span class="p">,</span> <span class="s2">"neutral"</span><span class="p">]</span> </code></pre></div> <hr /> <h2 id="load-the-data">Load the Data</h2> <div class="codehilite"><pre><span></span><code><span class="err">!</span><span class="n">curl</span> <span class="o">-</span><span class="n">LO</span> <span class="n">https</span><span class="p">:</span><span class="o">//</span><span class="n">raw</span><span class="o">.</span><span class="n">githubusercontent</span><span class="o">.</span><span class="n">com</span><span class="o">/</span><span class="n">MohamadMerchant</span><span class="o">/</span><span class="n">SNLI</span><span class="o">/</span><span class="n">master</span><span class="o">/</span><span class="n">data</span><span class="o">.</span><span class="n">tar</span><span class="o">.</span><span class="n">gz</span> <span class="err">!</span><span class="n">tar</span> <span class="o">-</span><span class="n">xvzf</span> <span class="n">data</span><span class="o">.</span><span class="n">tar</span><span class="o">.</span><span class="n">gz</span> </code></pre></div> <div class="codehilite"><pre><span></span><code><span class="c1"># There are more than 550k samples in total; we will use 100k for this example.</span> <span class="n">train_df</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_csv</span><span class="p">(</span><span class="s2">"SNLI_Corpus/snli_1.0_train.csv"</span><span class="p">,</span> <span class="n">nrows</span><span class="o">=</span><span class="mi">100000</span><span class="p">)</span> <span 
class="n">valid_df</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_csv</span><span class="p">(</span><span class="s2">"SNLI_Corpus/snli_1.0_dev.csv"</span><span class="p">)</span> <span class="n">test_df</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_csv</span><span class="p">(</span><span class="s2">"SNLI_Corpus/snli_1.0_test.csv"</span><span class="p">)</span> <span class="c1"># Shape of the data</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Total train samples : </span><span class="si">{</span><span class="n">train_df</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Total validation samples: </span><span class="si">{</span><span class="n">valid_df</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Total test samples: </span><span class="si">{</span><span class="n">test_df</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> % Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 100 11.1M 100 11.1M 0 0 5231k 0 0:00:02 0:00:02 --:--:-- 5231k SNLI_Corpus/ SNLI_Corpus/snli_1.0_dev.csv 
SNLI_Corpus/snli_1.0_train.csv SNLI_Corpus/snli_1.0_test.csv Total train samples : 100000 Total validation samples: 10000 Total test samples: 10000 </code></pre></div> </div> <p>Dataset Overview:</p> <ul> <li>sentence1: The premise caption that was supplied to the author of the pair.</li> <li>sentence2: The hypothesis caption that was written by the author of the pair.</li> <li>similarity: This is the label chosen by the majority of annotators. Where no majority exists, the label "-" is used (we will skip such samples here).</li> </ul> <p>Here are the "similarity" label values in our dataset:</p> <ul> <li>Contradiction: The hypothesis contradicts the premise.</li> <li>Entailment: The premise entails the hypothesis; the two sentences agree in meaning.</li> <li>Neutral: The hypothesis neither entails nor contradicts the premise.</li> </ul> <p>Let's look at one sample from the dataset:</p> <div class="codehilite"><pre><span></span><code><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Sentence1: </span><span class="si">{</span><span class="n">train_df</span><span class="o">.</span><span class="n">loc</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span><span class="w"> </span><span class="s1">'sentence1'</span><span class="p">]</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Sentence2: </span><span class="si">{</span><span class="n">train_df</span><span class="o">.</span><span class="n">loc</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span><span class="w"> </span><span class="s1">'sentence2'</span><span class="p">]</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Similarity: </span><span class="si">{</span><span class="n">train_df</span><span class="o">.</span><span
class="n">loc</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span><span class="w"> </span><span class="s1">'similarity'</span><span class="p">]</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Sentence1: A person on a horse jumps over a broken down airplane. Sentence2: A person is at a diner, ordering an omelette. Similarity: contradiction </code></pre></div> </div> <hr /> <h2 id="preprocessing">Preprocessing</h2> <div class="codehilite"><pre><span></span><code><span class="c1"># We have some NaN entries in our train data; we will simply drop them.</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Number of missing values"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">train_df</span><span class="o">.</span><span class="n">isnull</span><span class="p">()</span><span class="o">.</span><span class="n">sum</span><span class="p">())</span> <span class="n">train_df</span><span class="o">.</span><span class="n">dropna</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Number of missing values similarity 0 sentence1 0 sentence2 3 dtype: int64 </code></pre></div> </div> <p>Distribution of our training targets.</p> <div class="codehilite"><pre><span></span><code><span class="nb">print</span><span class="p">(</span><span class="s2">"Train Target Distribution"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">train_df</span><span class="o">.</span><span class="n">similarity</span><span
class="o">.</span><span class="n">value_counts</span><span class="p">())</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Train Target Distribution entailment 33384 contradiction 33310 neutral 33193 - 110 Name: similarity, dtype: int64 </code></pre></div> </div> <p>Distribution of our validation targets.</p> <div class="codehilite"><pre><span></span><code><span class="nb">print</span><span class="p">(</span><span class="s2">"Validation Target Distribution"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">valid_df</span><span class="o">.</span><span class="n">similarity</span><span class="o">.</span><span class="n">value_counts</span><span class="p">())</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Validation Target Distribution entailment 3329 contradiction 3278 neutral 3235 - 158 Name: similarity, dtype: int64 </code></pre></div> </div> <p>The value "-" appears as part of our training and validation targets. 
We will skip these samples.</p> <div class="codehilite"><pre><span></span><code><span class="n">train_df</span> <span class="o">=</span> <span class="p">(</span> <span class="n">train_df</span><span class="p">[</span><span class="n">train_df</span><span class="o">.</span><span class="n">similarity</span> <span class="o">!=</span> <span class="s2">"-"</span><span class="p">]</span> <span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">frac</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">)</span> <span class="o">.</span><span class="n">reset_index</span><span class="p">(</span><span class="n">drop</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="p">)</span> <span class="n">valid_df</span> <span class="o">=</span> <span class="p">(</span> <span class="n">valid_df</span><span class="p">[</span><span class="n">valid_df</span><span class="o">.</span><span class="n">similarity</span> <span class="o">!=</span> <span class="s2">"-"</span><span class="p">]</span> <span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">frac</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span><span class="p">)</span> <span class="o">.</span><span class="n">reset_index</span><span class="p">(</span><span class="n">drop</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="p">)</span> </code></pre></div> <p>One-hot encode training, validation, and test labels.</p> <div class="codehilite"><pre><span></span><code><span class="n">train_df</span><span class="p">[</span><span class="s2">"label"</span><span class="p">]</span> <span class="o">=</span> <span class="n">train_df</span><span 
class="p">[</span><span class="s2">"similarity"</span><span class="p">]</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="mi">0</span> <span class="k">if</span> <span class="n">x</span> <span class="o">==</span> <span class="s2">"contradiction"</span> <span class="k">else</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">x</span> <span class="o">==</span> <span class="s2">"entailment"</span> <span class="k">else</span> <span class="mi">2</span> <span class="p">)</span> <span class="n">y_train</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">to_categorical</span><span class="p">(</span><span class="n">train_df</span><span class="o">.</span><span class="n">label</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span> <span class="n">valid_df</span><span class="p">[</span><span class="s2">"label"</span><span class="p">]</span> <span class="o">=</span> <span class="n">valid_df</span><span class="p">[</span><span class="s2">"similarity"</span><span class="p">]</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="mi">0</span> <span class="k">if</span> <span class="n">x</span> <span class="o">==</span> <span class="s2">"contradiction"</span> <span class="k">else</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">x</span> <span class="o">==</span> <span class="s2">"entailment"</span> <span class="k">else</span> <span class="mi">2</span> <span class="p">)</span> <span class="n">y_val</span> <span class="o">=</span> <span class="n">tf</span><span 
class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">to_categorical</span><span class="p">(</span><span class="n">valid_df</span><span class="o">.</span><span class="n">label</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span> <span class="n">test_df</span><span class="p">[</span><span class="s2">"label"</span><span class="p">]</span> <span class="o">=</span> <span class="n">test_df</span><span class="p">[</span><span class="s2">"similarity"</span><span class="p">]</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="mi">0</span> <span class="k">if</span> <span class="n">x</span> <span class="o">==</span> <span class="s2">"contradiction"</span> <span class="k">else</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">x</span> <span class="o">==</span> <span class="s2">"entailment"</span> <span class="k">else</span> <span class="mi">2</span> <span class="p">)</span> <span class="n">y_test</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">to_categorical</span><span class="p">(</span><span class="n">test_df</span><span class="o">.</span><span class="n">label</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="create-a-custom-data-generator">Create a custom data generator</h2> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">BertSemanticDataGenerator</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span 
class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">Sequence</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Generates batches of data.</span> <span class="sd"> Args:</span> <span class="sd"> sentence_pairs: Array of premise and hypothesis input sentences.</span> <span class="sd"> labels: Array of labels.</span> <span class="sd"> batch_size: Integer batch size.</span> <span class="sd"> shuffle: boolean, whether to shuffle the data.</span> <span class="sd"> include_targets: boolean, whether to include the labels.</span> <span class="sd"> Returns:</span> <span class="sd"> Tuples `([input_ids, attention_mask, token_type_ids], labels)`</span> <span class="sd"> (or just `[input_ids, attention_mask, token_type_ids]`</span> <span class="sd"> if `include_targets=False`)</span> <span class="sd"> """</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">sentence_pairs</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">include_targets</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">sentence_pairs</span> <span class="o">=</span> <span class="n">sentence_pairs</span> <span class="bp">self</span><span class="o">.</span><span class="n">labels</span> <span class="o">=</span> <span class="n">labels</span> <span class="bp">self</span><span class="o">.</span><span class="n">shuffle</span> <span class="o">=</span> <span class="n">shuffle</span> <span class="bp">self</span><span
class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="bp">self</span><span class="o">.</span><span class="n">include_targets</span> <span class="o">=</span> <span class="n">include_targets</span> <span class="c1"># Load our BERT Tokenizer to encode the text.</span> <span class="c1"># We will use the bert-base-uncased pretrained model.</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokenizer</span> <span class="o">=</span> <span class="n">transformers</span><span class="o">.</span><span class="n">BertTokenizer</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span> <span class="s2">"bert-base-uncased"</span><span class="p">,</span> <span class="n">do_lower_case</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">indexes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sentence_pairs</span><span class="p">))</span> <span class="bp">self</span><span class="o">.</span><span class="n">on_epoch_end</span><span class="p">()</span> <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="c1"># Denotes the number of batches per epoch.</span> <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sentence_pairs</span><span class="p">)</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span
class="bp">self</span><span class="p">,</span> <span class="n">idx</span><span class="p">):</span> <span class="c1"># Retrieves the batch at the given index.</span> <span class="n">indexes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">indexes</span><span class="p">[</span><span class="n">idx</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="p">:</span> <span class="p">(</span><span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">]</span> <span class="n">sentence_pairs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sentence_pairs</span><span class="p">[</span><span class="n">indexes</span><span class="p">]</span> <span class="c1"># With the BERT tokenizer's batch_encode_plus, both sentences in a pair are</span> <span class="c1"># encoded together, separated by the [SEP] token.</span> <span class="n">encoded</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">batch_encode_plus</span><span class="p">(</span> <span class="n">sentence_pairs</span><span class="o">.</span><span class="n">tolist</span><span class="p">(),</span> <span class="n">add_special_tokens</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">max_length</span><span class="o">=</span><span class="n">max_length</span><span class="p">,</span> <span class="n">return_attention_mask</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">return_token_type_ids</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span
class="n">pad_to_max_length</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">return_tensors</span><span class="o">=</span><span class="s2">"tf"</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Convert batch of encoded features to numpy array.</span> <span class="n">input_ids</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">encoded</span><span class="p">[</span><span class="s2">"input_ids"</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int32"</span><span class="p">)</span> <span class="n">attention_masks</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">encoded</span><span class="p">[</span><span class="s2">"attention_mask"</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int32"</span><span class="p">)</span> <span class="n">token_type_ids</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">encoded</span><span class="p">[</span><span class="s2">"token_type_ids"</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int32"</span><span class="p">)</span> <span class="c1"># Set to true if data generator is used for training/validation.</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">include_targets</span><span class="p">:</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">labels</span><span class="p">[</span><span 
class="n">indexes</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int32"</span><span class="p">)</span> <span class="k">return</span> <span class="p">[</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">attention_masks</span><span class="p">,</span> <span class="n">token_type_ids</span><span class="p">],</span> <span class="n">labels</span> <span class="k">else</span><span class="p">:</span> <span class="k">return</span> <span class="p">[</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">attention_masks</span><span class="p">,</span> <span class="n">token_type_ids</span><span class="p">]</span> <span class="k">def</span> <span class="nf">on_epoch_end</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="c1"># Shuffle indexes after each epoch if shuffle is set to True.</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">shuffle</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">RandomState</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">indexes</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="build-the-model">Build the model</h2> <div class="codehilite"><pre><span></span><code><span class="c1"># Create the model under a distribution strategy scope.</span> <span class="n">strategy</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">distribute</span><span class="o">.</span><span class="n">MirroredStrategy</span><span class="p">()</span> <span class="k">with</span> <span class="n">strategy</span><span class="o">.</span><span 
class="n">scope</span><span class="p">():</span> <span class="c1"># Encoded token ids from BERT tokenizer.</span> <span class="n">input_ids</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">max_length</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"input_ids"</span> <span class="p">)</span> <span class="c1"># Attention masks indicate to the model which tokens should be attended to.</span> <span class="n">attention_masks</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">max_length</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"attention_masks"</span> <span class="p">)</span> <span class="c1"># Token type ids are binary masks identifying the two sequences in the input.</span> <span class="n">token_type_ids</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span
class="p">(</span><span class="n">max_length</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"token_type_ids"</span> <span class="p">)</span> <span class="c1"># Load the pretrained BERT model.</span> <span class="n">bert_model</span> <span class="o">=</span> <span class="n">transformers</span><span class="o">.</span><span class="n">TFBertModel</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s2">"bert-base-uncased"</span><span class="p">)</span> <span class="c1"># Freeze the BERT model to reuse the pretrained features without modifying them.</span> <span class="n">bert_model</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">False</span> <span class="n">bert_output</span> <span class="o">=</span> <span class="n">bert_model</span><span class="o">.</span><span class="n">bert</span><span class="p">(</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">attention_mask</span><span class="o">=</span><span class="n">attention_masks</span><span class="p">,</span> <span class="n">token_type_ids</span><span class="o">=</span><span class="n">token_type_ids</span> <span class="p">)</span> <span class="n">sequence_output</span> <span class="o">=</span> <span class="n">bert_output</span><span class="o">.</span><span class="n">last_hidden_state</span> <span class="n">pooled_output</span> <span class="o">=</span> <span class="n">bert_output</span><span class="o">.</span><span class="n">pooler_output</span> <span class="c1"># Add trainable layers on top of the frozen layers to adapt the pretrained features to the new data.</span> <span class="n">bi_lstm</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span
class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Bidirectional</span><span class="p">(</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">LSTM</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">return_sequences</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="p">)(</span><span class="n">sequence_output</span><span class="p">)</span> <span class="c1"># Applying hybrid pooling approach to bi_lstm sequence output.</span> <span class="n">avg_pool</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">GlobalAveragePooling1D</span><span class="p">()(</span><span class="n">bi_lstm</span><span class="p">)</span> <span class="n">max_pool</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">GlobalMaxPooling1D</span><span class="p">()(</span><span class="n">bi_lstm</span><span class="p">)</span> <span class="n">concat</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">avg_pool</span><span class="p">,</span> <span class="n">max_pool</span><span class="p">])</span> <span class="n">dropout</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span 
class="n">Dropout</span><span class="p">(</span><span class="mf">0.3</span><span class="p">)(</span><span class="n">concat</span><span class="p">)</span> <span class="n">output</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"softmax"</span><span class="p">)(</span><span class="n">dropout</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span> <span class="n">inputs</span><span class="o">=</span><span class="p">[</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">attention_masks</span><span class="p">,</span> <span class="n">token_type_ids</span><span class="p">],</span> <span class="n">outputs</span><span class="o">=</span><span class="n">output</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(),</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"categorical_crossentropy"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"acc"</span><span class="p">],</span> <span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span 
class="sa">f</span><span class="s2">"Strategy: </span><span class="si">{</span><span class="n">strategy</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_… </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>HBox(children=(FloatProgress(value=0.0, description='Downloading', max=536063208.0, style=ProgressStyle(descri… </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Strategy: <tensorflow.python.distribute.mirrored_strategy.MirroredStrategy object at 0x7faf9dc63a90> Model: "functional_1" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_ids (InputLayer) [(None, 128)] 0 __________________________________________________________________________________________________ attention_masks (InputLayer) [(None, 128)] 0 __________________________________________________________________________________________________ token_type_ids (InputLayer) [(None, 128)] 0 __________________________________________________________________________________________________ tf_bert_model (TFBertModel) ((None, 128, 768), ( 109482240 input_ids[0][0] attention_masks[0][0] token_type_ids[0][0] __________________________________________________________________________________________________ bidirectional (Bidirectional) (None, 128, 128) 426496 tf_bert_model[0][0] 
__________________________________________________________________________________________________ global_average_pooling1d (Globa (None, 128) 0 bidirectional[0][0] __________________________________________________________________________________________________ global_max_pooling1d (GlobalMax (None, 128) 0 bidirectional[0][0] __________________________________________________________________________________________________ concatenate (Concatenate) (None, 256) 0 global_average_pooling1d[0][0] global_max_pooling1d[0][0] __________________________________________________________________________________________________ dropout_37 (Dropout) (None, 256) 0 concatenate[0][0] __________________________________________________________________________________________________ dense (Dense) (None, 3) 771 dropout_37[0][0] ================================================================================================== Total params: 109,909,507 Trainable params: 427,267 Non-trainable params: 109,482,240 __________________________________________________________________________________________________ </code></pre></div> </div> <p>Create train and validation data generators</p> <div class="codehilite"><pre><span></span><code><span class="n">train_data</span> <span class="o">=</span> <span class="n">BertSemanticDataGenerator</span><span class="p">(</span> <span class="n">train_df</span><span class="p">[[</span><span class="s2">"sentence1"</span><span class="p">,</span> <span class="s2">"sentence2"</span><span class="p">]]</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"str"</span><span class="p">),</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span 
class="p">,</span> <span class="p">)</span> <span class="n">valid_data</span> <span class="o">=</span> <span class="n">BertSemanticDataGenerator</span><span class="p">(</span> <span class="n">valid_df</span><span class="p">[[</span><span class="s2">"sentence1"</span><span class="p">,</span> <span class="s2">"sentence2"</span><span class="p">]]</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"str"</span><span class="p">),</span> <span class="n">y_val</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti… </code></pre></div> </div> <hr /> <h2 id="train-the-model">Train the Model</h2> <p>Training is done only for the top layers to perform "feature extraction", which will allow the model to use the representations of the pretrained model.</p> <div class="codehilite"><pre><span></span><code><span class="n">history</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">train_data</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">valid_data</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">epochs</span><span class="p">,</span> <span class="n">use_multiprocessing</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">workers</span><span class="o">=-</span><span class="mi">1</span><span 
class="p">,</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/2 3121/3121 [==============================] - 666s 213ms/step - loss: 0.6925 - acc: 0.7049 - val_loss: 0.5294 - val_acc: 0.7899 Epoch 2/2 3121/3121 [==============================] - 661s 212ms/step - loss: 0.5917 - acc: 0.7587 - val_loss: 0.4955 - val_acc: 0.8052 </code></pre></div> </div> <hr /> <h2 id="finetuning">Fine-tuning</h2> <p>This step must only be performed after the feature extraction model has been trained to convergence on the new data.</p> <p>This is an optional last step where <code>bert_model</code> is unfrozen and retrained with a very low learning rate. This can deliver meaningful improvement by incrementally adapting the pretrained features to the new data.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Unfreeze the bert_model.</span> <span class="n">bert_model</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">True</span> <span class="c1"># Recompile the model to make the change effective.</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="mf">1e-5</span><span class="p">),</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"categorical_crossentropy"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"accuracy"</span><span class="p">],</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> 
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Model: "functional_1" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_ids (InputLayer) [(None, 128)] 0 __________________________________________________________________________________________________ attention_masks (InputLayer) [(None, 128)] 0 __________________________________________________________________________________________________ token_type_ids (InputLayer) [(None, 128)] 0 __________________________________________________________________________________________________ tf_bert_model (TFBertModel) ((None, 128, 768), ( 109482240 input_ids[0][0] attention_masks[0][0] token_type_ids[0][0] __________________________________________________________________________________________________ bidirectional (Bidirectional) (None, 128, 128) 426496 tf_bert_model[0][0] __________________________________________________________________________________________________ global_average_pooling1d (Globa (None, 128) 0 bidirectional[0][0] __________________________________________________________________________________________________ global_max_pooling1d (GlobalMax (None, 128) 0 bidirectional[0][0] __________________________________________________________________________________________________ concatenate (Concatenate) (None, 256) 0 global_average_pooling1d[0][0] global_max_pooling1d[0][0] __________________________________________________________________________________________________ dropout_37 (Dropout) (None, 256) 0 concatenate[0][0] __________________________________________________________________________________________________ dense (Dense) (None, 3) 771 dropout_37[0][0] 
================================================================================================== Total params: 109,909,507 Trainable params: 109,909,507 Non-trainable params: 0 __________________________________________________________________________________________________ </code></pre></div> </div> <h2 id="train-the-entire-model-endtoend">Train the entire model end-to-end</h2> <div class="codehilite"><pre><span></span><code><span class="n">history</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">train_data</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">valid_data</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">epochs</span><span class="p">,</span> <span class="n">use_multiprocessing</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">workers</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/2 3121/3121 [==============================] - 1574s 504ms/step - loss: 0.4698 - accuracy: 0.8181 - val_loss: 0.3787 - val_accuracy: 0.8598 Epoch 2/2 3121/3121 [==============================] - 1569s 503ms/step - loss: 0.3516 - accuracy: 0.8702 - val_loss: 0.3416 - val_accuracy: 0.8757 </code></pre></div> </div> <hr /> <h2 id="evaluate-model-on-the-test-set">Evaluate model on the test set</h2> <div class="codehilite"><pre><span></span><code><span class="n">test_data</span> <span class="o">=</span> <span class="n">BertSemanticDataGenerator</span><span class="p">(</span> <span class="n">test_df</span><span class="p">[[</span><span class="s2">"sentence1"</span><span class="p">,</span> <span class="s2">"sentence2"</span><span class="p">]]</span><span 
class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"str"</span><span class="p">),</span> <span class="n">y_test</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">test_data</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>312/312 [==============================] - 55s 177ms/step - loss: 0.3697 - accuracy: 0.8629 [0.3696725070476532, 0.8628805875778198] </code></pre></div> </div> <hr /> <h2 id="inference-on-custom-sentences">Inference on custom sentences</h2> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">check_similarity</span><span class="p">(</span><span class="n">sentence1</span><span class="p">,</span> <span class="n">sentence2</span><span class="p">):</span> <span class="n">sentence_pairs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="nb">str</span><span class="p">(</span><span class="n">sentence1</span><span class="p">),</span> <span class="nb">str</span><span class="p">(</span><span class="n">sentence2</span><span class="p">)]])</span> <span class="n">test_data</span> <span class="o">=</span> <span class="n">BertSemanticDataGenerator</span><span class="p">(</span> <span class="n">sentence_pairs</span><span class="p">,</span> <span class="n">labels</span><span class="o">=</span><span 
class="kc">None</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">include_targets</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="p">)</span> <span class="n">proba</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">test_data</span><span class="p">[</span><span class="mi">0</span><span class="p">])[</span><span class="mi">0</span><span class="p">]</span> <span class="n">idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">proba</span><span class="p">)</span> <span class="c1"># Convert the winning probability to a percentage before formatting.</span> <span class="n">proba</span> <span class="o">=</span> <span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">proba</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">*</span> <span class="mi">100</span><span class="si">:</span><span class="s2">.0f</span><span class="si">}</span><span class="s2">%"</span> <span class="n">pred</span> <span class="o">=</span> <span class="n">labels</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="k">return</span> <span class="n">pred</span><span class="p">,</span> <span class="n">proba</span> </code></pre></div> <p>Check results on some example sentence pairs.</p> <div class="codehilite"><pre><span></span><code><span class="n">sentence1</span> <span class="o">=</span> <span class="s2">"Two women are observing something together."</span> <span class="n">sentence2</span> <span class="o">=</span> <span class="s2">"Two women are standing with their eyes closed."</span> <span class="n">check_similarity</span><span class="p">(</span><span
class="n">sentence1</span><span class="p">,</span> <span class="n">sentence2</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>('contradiction', '91%') </code></pre></div> </div> <p>Check results on some example sentence pairs.</p> <div class="codehilite"><pre><span></span><code><span class="n">sentence1</span> <span class="o">=</span> <span class="s2">"A smiling costumed woman is holding an umbrella"</span> <span class="n">sentence2</span> <span class="o">=</span> <span class="s2">"A happy woman in a fairy costume holds an umbrella"</span> <span class="n">check_similarity</span><span class="p">(</span><span class="n">sentence1</span><span class="p">,</span> <span class="n">sentence2</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>('neutral', '88%') </code></pre></div> </div> <p>Check results on some example sentence pairs.</p> <div class="codehilite"><pre><span></span><code><span class="n">sentence1</span> <span class="o">=</span> <span class="s2">"A soccer game with multiple males playing"</span> <span class="n">sentence2</span> <span class="o">=</span> <span class="s2">"Some men are playing a sport"</span> <span class="n">check_similarity</span><span class="p">(</span><span class="n">sentence1</span><span class="p">,</span> <span class="n">sentence2</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>('entailment', '94%') </code></pre></div> </div> <p>Example available on HuggingFace</p> <table> <thead> <tr> <th style="text-align: center;">Trained Model</th> <th style="text-align: center;">Demo</th> </tr> </thead> <tbody> <tr> <td style="text-align: center;"><a href="https://huggingface.co/keras-io/bert-semantic-similarity"><img alt="Generic badge" 
src="https://img.shields.io/badge/%F0%9F%A4%97%20Model-semantic%20similarity%20with%20bert-black.svg" /></a></td> <td style="text-align: center;"><a href="https://huggingface.co/spaces/keras-io/bert-semantic-similarity"><img alt="Generic badge" src="https://img.shields.io/badge/%F0%9F%A4%97%20Spaces-semantic%20similarity%20with%20bert-black.svg" /></a></td> </tr> </tbody> </table> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#semantic-similarity-with-bert'>Semantic Similarity with BERT</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-3'> <a href='#references'>References</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#configuration'>Configuration</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#load-the-data'>Load the Data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#preprocessing'>Preprocessing</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#create-a-custom-data-generator'>Create a custom data generator</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#build-the-model'>Build the model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-the-model'>Train the Model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#finetuning'>Fine-tuning</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-the-entire-model-endtoend'>Train the entire model end-to-end</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#evaluate-model-on-the-test-set'>Evaluate model on the test set</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#inference-on-custom-sentences'>Inference on custom sentences</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>