<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/nlp/sentence_embeddings_with_sbert/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Sentence embeddings using Siamese RoBERTa-networks"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Sentence embeddings using Siamese RoBERTa-networks"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Sentence embeddings using Siamese RoBERTa-networks</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 
'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink active" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_from_scratch/">Text classification from scratch</a> <a class="nav-sublink2" href="/examples/nlp/active_learning_review_classification/">Review Classification using Active Learning</a> <a class="nav-sublink2" href="/examples/nlp/fnet_classification_with_keras_hub/">Text Classification using FNet</a> <a class="nav-sublink2" href="/examples/nlp/multi_label_classification/">Large-scale multi-label text classification</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_with_transformer/">Text classification with Transformer</a> <a class="nav-sublink2" 
href="/examples/nlp/text_classification_with_switch_transformer/">Text classification with Switch Transformer</a> <a class="nav-sublink2" href="/examples/nlp/tweet-classification-using-tfdf/">Text classification using Decision Forests and pretrained embeddings</a> <a class="nav-sublink2" href="/examples/nlp/pretrained_word_embeddings/">Using pre-trained word embeddings</a> <a class="nav-sublink2" href="/examples/nlp/bidirectional_lstm_imdb/">Bidirectional LSTM on IMDB</a> <a class="nav-sublink2" href="/examples/nlp/data_parallel_training_with_keras_hub/">Data Parallel Training with KerasHub and tf.distribute</a> <a class="nav-sublink2" href="/examples/nlp/neural_machine_translation_with_keras_hub/">English-to-Spanish translation with KerasHub</a> <a class="nav-sublink2" href="/examples/nlp/neural_machine_translation_with_transformer/">English-to-Spanish translation with a sequence-to-sequence Transformer</a> <a class="nav-sublink2" href="/examples/nlp/lstm_seq2seq/">Character-level recurrent sequence-to-sequence model</a> <a class="nav-sublink2" href="/examples/nlp/multimodal_entailment/">Multimodal entailment</a> <a class="nav-sublink2" href="/examples/nlp/ner_transformers/">Named Entity Recognition using Transformers</a> <a class="nav-sublink2" href="/examples/nlp/text_extraction_with_bert/">Text Extraction with BERT</a> <a class="nav-sublink2" href="/examples/nlp/addition_rnn/">Sequence to sequence learning for performing number addition</a> <a class="nav-sublink2" href="/examples/nlp/semantic_similarity_with_keras_hub/">Semantic Similarity with KerasHub</a> <a class="nav-sublink2" href="/examples/nlp/semantic_similarity_with_bert/">Semantic Similarity with BERT</a> <a class="nav-sublink2 active" href="/examples/nlp/sentence_embeddings_with_sbert/">Sentence embeddings using Siamese RoBERTa-networks</a> <a class="nav-sublink2" href="/examples/nlp/masked_language_modeling/">End-to-end Masked Language Modeling with BERT</a> <a class="nav-sublink2" 
href="/examples/nlp/abstractive_summarization_with_bart/">Abstractive Text Summarization with BART</a> <a class="nav-sublink2" href="/examples/nlp/pretraining_BERT/">Pretraining BERT with Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/">Parameter-efficient fine-tuning of GPT-2 with LoRA</a> <a class="nav-sublink2" href="/examples/nlp/mlm_training_tpus/">Training a language model from scratch with 🤗 Transformers and TPUs</a> <a class="nav-sublink2" href="/examples/nlp/multiple_choice_task_with_transfer_learning/">MultipleChoice Task with Transfer Learning</a> <a class="nav-sublink2" href="/examples/nlp/question_answering/">Question Answering with Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/nlp/t5_hf_summarization/">Abstractive Summarization with Hugging Face Transformers</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = 
"none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." 
autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/nlp/'>Natural Language Processing</a> / Sentence embeddings using Siamese RoBERTa-networks </div> <div class='k-content'> <h1 id="sentence-embeddings-using-siamese-robertanetworks">Sentence embeddings using Siamese RoBERTa-networks</h1> <p><strong>Author:</strong> <a href="https://github.com/abuelnasr0">Mohammed Abu El-Nasr</a><br> <strong>Date created:</strong> 2023/07/14<br> <strong>Last modified:</strong> 2023/07/14<br> <strong>Description:</strong> Fine-tune a RoBERTa model to generate sentence embeddings using KerasHub.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 
3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/nlp/ipynb/sentence_embeddings_with_sbert.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/nlp/sentence_embeddings_with_sbert.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>BERT and RoBERTa can be used for semantic textual similarity tasks, where two sentences are passed to the model together and the network predicts whether they are similar or not. But what if we have a large collection of sentences and want to find the most similar pairs in that collection? That takes n*(n-1)/2 inference computations, where n is the number of sentences in the collection. For example, if n = 10,000, that is about 50 million computations, which takes about 65 hours on a V100 GPU.</p> <p>A common way to avoid this overhead is to pass each sentence to the model on its own, then either average the model's output or take the output of the first token (the [CLS] token), and use the result as a <a href="https://en.wikipedia.org/wiki/Sentence_embedding">sentence embedding</a>. A vector similarity measure such as cosine similarity or Manhattan / Euclidean distance can then be used to find close (semantically similar) sentences. That reduces the time to find the most similar pairs in a collection of 10,000 sentences from 65 hours to 5 seconds!</p> <p>Using RoBERTa directly yields rather poor sentence embeddings. But if we fine-tune RoBERTa using a Siamese network, it will generate semantically meaningful sentence embeddings. This will enable RoBERTa to be used for new tasks. 
These tasks include:</p> <ul> <li>Large-scale semantic similarity comparison.</li> <li>Clustering.</li> <li>Information retrieval via semantic search.</li> </ul> <p>In this example, we will show how to fine-tune a RoBERTa model using a Siamese network so that it produces semantically meaningful sentence embeddings, and then use those embeddings in a semantic search and clustering example. This method of fine-tuning was introduced in <a href="https://arxiv.org/abs/1908.10084">Sentence-BERT</a>.</p> <hr /> <h2 id="setup">Setup</h2> <p>Let's install and import the libraries we need. We'll be using the KerasHub library in this example.</p> <p>We will also enable <a href="https://www.tensorflow.org/guide/mixed_precision">mixed precision</a> training. This will help us reduce the training time.</p> <div class="codehilite"><pre><span></span><code><span class="err">!</span><span class="n">pip</span> <span class="n">install</span> <span class="o">-</span><span class="n">q</span> <span class="o">--</span><span class="n">upgrade</span> <span class="n">keras</span><span class="o">-</span><span class="n">hub</span> <span class="err">!</span><span class="n">pip</span> <span class="n">install</span> <span class="o">-</span><span class="n">q</span> <span class="o">--</span><span class="n">upgrade</span> <span class="n">keras</span> <span class="c1"># Upgrade to Keras 3.</span> </code></pre></div> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"tensorflow"</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="nn">keras_hub</span> <span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span> <span 
class="kn">import</span> <span class="nn">tensorflow_datasets</span> <span class="k">as</span> <span class="nn">tfds</span> <span class="kn">import</span> <span class="nn">sklearn.cluster</span> <span class="k">as</span> <span class="nn">cluster</span> <span class="n">keras</span><span class="o">.</span><span class="n">mixed_precision</span><span class="o">.</span><span class="n">set_global_policy</span><span class="p">(</span><span class="s2">"mixed_float16"</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="finetune-the-model-using-siamese-networks">Fine-tune the model using Siamese networks</h2> <p>A <a href="https://en.wikipedia.org/wiki/Siamese_neural_network">Siamese network</a> is a neural network architecture that contains two or more subnetworks. The subnetworks share the same weights. It is used to generate feature vectors for each input and then compare them for similarity.</p> <p>For our example, the subnetwork will be a RoBERTa model that has a pooling layer on top of it to produce the embeddings of the input sentences. During training, these embeddings are compared to each other so that the model learns to produce semantically meaningful embeddings.</p> <p>The available pooling strategies are mean, max, and CLS pooling. Mean pooling produces the best results, so we will use it in our examples.</p> <h3 id="finetune-using-the-regression-objective-function">Fine-tune using the regression objective function</h3> <p>With the regression objective function, the Siamese network is asked to predict the cosine similarity between the embeddings of the two input sentences.</p> <p>Cosine similarity is the cosine of the angle between the sentence embeddings. A high cosine similarity means there is a small angle between the embeddings; hence, the sentences are semantically similar.</p> <h4 id="load-the-dataset">Load the dataset</h4> <p>We will use the STSB dataset to fine-tune the model for the regression objective. 
STSB consists of a collection of sentence pairs that are labelled in the range [0, 5], where 0 indicates the least semantic similarity between the two sentences and 5 indicates the most.</p> <p>The output of the Siamese network is a cosine similarity in the range [-1, 1], but the labels in the dataset are in the range [0, 5]. We need to bring both to the same range, so while preparing the dataset, we will divide the labels by 2.5 and subtract 1.</p> <div class="codehilite"><pre><span></span><code><span class="n">TRAIN_BATCH_SIZE</span> <span class="o">=</span> <span class="mi">6</span> <span class="n">VALIDATION_BATCH_SIZE</span> <span class="o">=</span> <span class="mi">8</span> <span class="n">TRAIN_NUM_BATCHES</span> <span class="o">=</span> <span class="mi">300</span> <span class="n">VALIDATION_NUM_BATCHES</span> <span class="o">=</span> <span class="mi">40</span> <span class="n">AUTOTUNE</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span> <span class="k">def</span> <span class="nf">change_range</span><span class="p">(</span><span class="n">x</span><span class="p">):</span> <span class="k">return</span> <span class="p">(</span><span class="n">x</span> <span class="o">/</span> <span class="mf">2.5</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span> <span class="k">def</span> <span class="nf">prepare_dataset</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">num_batches</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">):</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">map</span><span class="p">(</span> <span 
class="k">lambda</span> <span class="n">z</span><span class="p">:</span> <span class="p">(</span> <span class="p">[</span><span class="n">z</span><span class="p">[</span><span class="s2">"sentence1"</span><span class="p">],</span> <span class="n">z</span><span class="p">[</span><span class="s2">"sentence2"</span><span class="p">]],</span> <span class="p">[</span><span class="n">tf</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">change_range</span><span class="p">(</span><span class="n">z</span><span class="p">[</span><span class="s2">"label"</span><span class="p">]),</span> <span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">)],</span> <span class="p">),</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTOTUNE</span><span class="p">,</span> <span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="n">num_batches</span><span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">AUTOTUNE</span><span class="p">)</span> <span class="k">return</span> <span class="n">dataset</span> <span class="n">stsb_ds</span> <span class="o">=</span> <span class="n">tfds</span><span class="o">.</span><span class="n">load</span><span class="p">(</span> <span class="s2">"glue/stsb"</span><span class="p">,</span> <span class="p">)</span> <span class="n">stsb_train</span><span class="p">,</span> <span class="n">stsb_valid</span> <span class="o">=</span> <span 
class="n">stsb_ds</span><span class="p">[</span><span class="s2">"train"</span><span class="p">],</span> <span class="n">stsb_ds</span><span class="p">[</span><span class="s2">"validation"</span><span class="p">]</span> <span class="n">stsb_train</span> <span class="o">=</span> <span class="n">prepare_dataset</span><span class="p">(</span><span class="n">stsb_train</span><span class="p">,</span> <span class="n">TRAIN_NUM_BATCHES</span><span class="p">,</span> <span class="n">TRAIN_BATCH_SIZE</span><span class="p">)</span> <span class="n">stsb_valid</span> <span class="o">=</span> <span class="n">prepare_dataset</span><span class="p">(</span><span class="n">stsb_valid</span><span class="p">,</span> <span class="n">VALIDATION_NUM_BATCHES</span><span class="p">,</span> <span class="n">VALIDATION_BATCH_SIZE</span><span class="p">)</span> </code></pre></div> <p>Let's look at a few examples of sentence pairs from the dataset and their similarity scores.</p> <div class="codehilite"><pre><span></span><code><span class="k">for</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">stsb_train</span><span class="p">:</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">example</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">x</span><span class="p">):</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"sentence 1 : </span><span class="si">{</span><span class="n">example</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="si">}</span><span class="s2"> "</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"sentence 2 : </span><span class="si">{</span><span class="n">example</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span 
class="si">}</span><span class="s2"> "</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"similarity : </span><span class="si">{</span><span class="n">y</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="si">}</span><span class="s2"> </span><span class="se">\n</span><span class="s2">"</span><span class="p">)</span> <span class="k">break</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>sentence 1 : b"A young girl is sitting on Santa's lap." sentence 2 : b"A little girl is sitting on Santa's lap" similarity : [0.9200001] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>sentence 1 : b'A women sitting at a table drinking with a basketball picture in the background.' sentence 2 : b'A woman in a sari drinks something while sitting at a table.' similarity : [0.03999996] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>sentence 1 : b'Norway marks anniversary of massacre' sentence 2 : b"Norway Marks Anniversary of Breivik's Massacre" similarity : [0.52] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>sentence 1 : b'US drone kills six militants in Pakistan: officials' sentence 2 : b'US missiles kill 15 in Pakistan: officials' similarity : [-0.03999996] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>sentence 1 : b'On Tuesday, the central bank left interest rates steady, as expected, but also declared that overall risks were weighted toward weakness and warned of deflation risks.' sentence 2 : b"The central bank's policy board left rates steady for now, as widely expected, but surprised the market by declaring that overall risks were weighted toward weakness." 
similarity : [0.6] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>sentence 1 : b'At one of the three sampling sites at Huntington Beach, the bacteria reading came back at 160 on June 16 and at 120 on June 23.' sentence 2 : b'The readings came back at 160 on June 16 and 120 at June 23 at one of three sampling sites at Huntington Beach.' similarity : [0.29999995] </code></pre></div> </div> <h4 id="build-the-encoder-model">Build the encoder model.</h4> <p>Now, we'll build the encoder model that will produce the sentence embeddings. It consists of:</p> <ul> <li>A preprocessor layer to tokenize and generate padding masks for the sentences.</li> <li>A backbone model that will generate the contextual representation of each token in the sentence.</li> <li>A mean pooling layer to produce the embeddings. We will use <a href="/api/layers/pooling_layers/global_average_pooling1d#globalaveragepooling1d-class"><code>keras.layers.GlobalAveragePooling1D</code></a> to apply the mean pooling to the backbone outputs. 
We will pass the padding mask to the layer to exclude padded tokens from being averaged.</li> <li>A normalization layer to normalize the embeddings as we are using the cosine similarity.</li> </ul> <div class="codehilite"><pre><span></span><code><span class="n">preprocessor</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">RobertaPreprocessor</span><span class="o">.</span><span class="n">from_preset</span><span class="p">(</span><span class="s2">"roberta_base_en"</span><span class="p">)</span> <span class="n">backbone</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">RobertaBackbone</span><span class="o">.</span><span class="n">from_preset</span><span class="p">(</span><span class="s2">"roberta_base_en"</span><span class="p">)</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"string"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"sentence"</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">preprocessor</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">h</span> <span class="o">=</span> <span class="n">backbone</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">embedding</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">GlobalAveragePooling1D</span><span 
class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"pooling_layer"</span><span class="p">)(</span> <span class="n">h</span><span class="p">,</span> <span class="n">x</span><span class="p">[</span><span class="s2">"padding_mask"</span><span class="p">]</span> <span class="p">)</span> <span class="n">n_embedding</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">UnitNormalization</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)(</span><span class="n">embedding</span><span class="p">)</span> <span class="n">roberta_normal_encoder</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">n_embedding</span><span class="p">)</span> <span class="n">roberta_normal_encoder</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> </code></pre></div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "functional_1"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃<span style="font-weight: bold"> Connected to </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩ │ sentence │ (<span style="color: #00d7ff; 
text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ - │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">InputLayer</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ roberta_preprocess… │ [(<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>), │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ sentence[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>] │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">RobertaPreprocess…</span> │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>)] │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ roberta_backbone │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>, <span style="color: #00af00; text-decoration-color: #00af00">768</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">124,05…</span> │ roberta_preprocesso… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">RobertaBackbone</span>) │ │ │ roberta_preprocesso… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ pooling_layer │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">768</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ roberta_backbone[<span style="color: #00af00; text-decoration-color: #00af00">0</span>]… │ │ (<span style="color: #0087ff; 
text-decoration-color: #0087ff">GlobalAveragePool…</span> │ │ │ roberta_preprocesso… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ unit_normalization │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">768</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ pooling_layer[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>] │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">UnitNormalization</span>) │ │ │ │ └─────────────────────┴───────────────────┴─────────┴──────────────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">124,052,736</span> (473.22 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">124,052,736</span> (473.22 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <h4 id="build-the-siamese-network-with-the-regression-objective-function">Build the Siamese network with the regression objective function</h4> <p>As described above, a Siamese network has two or more subnetworks, so this model needs two encoder paths. Rather than building two separate encoders, we will pass both sentences through a single encoder. 
That way, we can have two paths to get the embeddings and also shared weights between the two paths.</p> <p>After passing the two sentences to the model and getting the normalized embeddings, we will multiply the two normalized embeddings to get the cosine similarity between the two sentences.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">RegressionSiamese</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">encoder</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"string"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"sentences"</span><span class="p">)</span> <span class="n">sen1</span><span class="p">,</span> <span class="n">sen2</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="n">u</span> <span class="o">=</span> <span class="n">encoder</span><span class="p">(</span><span class="n">sen1</span><span class="p">)</span> <span class="n">v</span> <span class="o">=</span> 
<span class="n">encoder</span><span class="p">(</span><span class="n">sen2</span><span class="p">)</span> <span class="n">cosine_similarity_scores</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">u</span><span class="p">,</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">v</span><span class="p">))</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span> <span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">cosine_similarity_scores</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">encoder</span> <span class="k">def</span> <span class="nf">get_encoder</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> </code></pre></div> <h4 id="fit-the-model">Fit the model</h4> <p>Let's try this example before training and compare it to the output after training.</p> <div class="codehilite"><pre><span></span><code><span class="n">sentences</span> <span class="o">=</span> <span class="p">[</span> <span class="s2">"Today is a very sunny day."</span><span class="p">,</span> <span class="s2">"I am hungry, I will get my meal."</span><span class="p">,</span> <span class="s2">"The dog is eating his food."</span><span class="p">,</span> <span 
class="p">]</span> <span class="n">query</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"The dog is enjoying his meal."</span><span class="p">]</span> <span class="n">encoder</span> <span class="o">=</span> <span class="n">roberta_normal_encoder</span> <span class="n">sentence_embeddings</span> <span class="o">=</span> <span class="n">encoder</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="n">sentences</span><span class="p">))</span> <span class="n">query_embedding</span> <span class="o">=</span> <span class="n">encoder</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="n">query</span><span class="p">))</span> <span class="n">cosine_similarity_scores</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">query_embedding</span><span class="p">,</span> <span class="n">tf</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">sentence_embeddings</span><span class="p">))</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">sim</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">cosine_similarity_scores</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"cosine similarity score between sentence </span><span class="si">{</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s2"> and the query = </span><span class="si">{</span><span class="n">sim</span><span class="si">}</span><span class="s2"> "</span><span class="p">)</span> 
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>cosine similarity score between sentence 1 and the query = 0.96630859375 cosine similarity score between sentence 2 and the query = 0.97607421875 cosine similarity score between sentence 3 and the query = 0.99365234375 </code></pre></div> </div> <p>Before fine-tuning, all three scores are close to 1, so the embeddings barely separate related sentences from unrelated ones.</p> <p>For training, we will use <code>MeanSquaredError()</code> as the loss function and the <code>Adam()</code> optimizer with a learning rate of 2e-5.</p> <div class="codehilite"><pre><span></span><code><span class="n">roberta_regression_siamese</span> <span class="o">=</span> <span class="n">RegressionSiamese</span><span class="p">(</span><span class="n">roberta_normal_encoder</span><span class="p">)</span> <span class="n">roberta_regression_siamese</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">MeanSquaredError</span><span class="p">(),</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="mf">2e-5</span><span class="p">),</span> <span class="n">jit_compile</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="p">)</span> <span class="n">roberta_regression_siamese</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">stsb_train</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">stsb_valid</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div 
class="codehilite"><pre><span></span><code> 300/300 ━━━━━━━━━━━━━━━━━━━━ 115s 297ms/step - loss: 0.4751 - val_loss: 0.4025 &lt;keras.src.callbacks.history.History at 0x7f5a78392140&gt; </code></pre></div> </div> <p>Let's try the model again after training. The output now differs sharply: after fine-tuning, the model produces semantically meaningful embeddings, in which semantically similar sentences have a small angle between them and semantically dissimilar sentences have a large angle between them.</p> <div class="codehilite"><pre><span></span><code><span class="n">sentences</span> <span class="o">=</span> <span class="p">[</span> <span class="s2">"Today is a very sunny day."</span><span class="p">,</span> <span class="s2">"I am hungry, I will get my meal."</span><span class="p">,</span> <span class="s2">"The dog is eating his food."</span><span class="p">,</span> <span class="p">]</span> <span class="n">query</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"The dog is enjoying his food."</span><span class="p">]</span> <span class="n">encoder</span> <span class="o">=</span> <span class="n">roberta_regression_siamese</span><span class="o">.</span><span class="n">get_encoder</span><span class="p">()</span> <span class="n">sentence_embeddings</span> <span class="o">=</span> <span class="n">encoder</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="n">sentences</span><span class="p">))</span> <span class="n">query_embedding</span> <span class="o">=</span> <span class="n">encoder</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="n">query</span><span class="p">))</span> <span class="n">cosine_similarities</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">matmul</span><span 
class="p">(</span><span class="n">query_embedding</span><span class="p">,</span> <span class="n">tf</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">sentence_embeddings</span><span class="p">))</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">sim</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">cosine_similarities</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"cosine similarity between sentence </span><span class="si">{</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s2"> and the query = </span><span class="si">{</span><span class="n">sim</span><span class="si">}</span><span class="s2"> "</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>cosine similarity between sentence 1 and the query = 0.10986328125 cosine similarity between sentence 2 and the query = 0.53466796875 cosine similarity between sentence 3 and the query = 0.83544921875 </code></pre></div> </div> <h3 id="finetune-using-the-triplet-objective-function">Fine-tune using the triplet objective function</h3> <p>For the Siamese network with the triplet objective function, three sentences are passed to the network: an <em>anchor</em>, a <em>positive</em>, and a <em>negative</em> sentence. The <em>anchor</em> and <em>positive</em> sentences are semantically similar, while the <em>anchor</em> and <em>negative</em> sentences are semantically dissimilar. 
The objective is to minimize the distance between the <em>anchor</em> sentence and the <em>positive</em> sentence, and to maximize the distance between the <em>anchor</em> sentence and the <em>negative</em> sentence.</p> <h4 id="load-the-dataset">Load the dataset</h4> <p>We will use the Wikipedia-sections-triplets dataset for fine-tuning. This dataset consists of sentences derived from Wikipedia. Each example is a triplet of sentences: <em>anchor</em>, <em>positive</em>, and <em>negative</em>. The <em>anchor</em> and <em>positive</em> sentences come from the same section, while the <em>anchor</em> and <em>negative</em> sentences come from different sections.</p> <p>This dataset has 1.8 million training triplets and 220,000 test triplets. In this example, we will only use 1,200 triplets for training (200 batches of 6) and 450 for testing (75 batches of 6).</p> <div class="codehilite"><pre><span></span><code><span class="err">!</span><span class="n">wget</span> <span class="n">https</span><span class="p">:</span><span class="o">//</span><span class="n">sbert</span><span class="o">.</span><span class="n">net</span><span class="o">/</span><span class="n">datasets</span><span class="o">/</span><span class="n">wikipedia</span><span class="o">-</span><span class="n">sections</span><span class="o">-</span><span class="n">triplets</span><span class="o">.</span><span class="n">zip</span> <span class="o">-</span><span class="n">q</span> <span class="err">!</span><span class="n">unzip</span> <span class="n">wikipedia</span><span class="o">-</span><span class="n">sections</span><span class="o">-</span><span class="n">triplets</span><span class="o">.</span><span class="n">zip</span> <span class="o">-</span><span class="n">d</span> <span class="n">wikipedia</span><span class="o">-</span><span class="n">sections</span><span class="o">-</span><span class="n">triplets</span> </code></pre></div> <div class="codehilite"><pre><span></span><code><span class="n">NUM_TRAIN_BATCHES</span> <span class="o">=</span> <span class="mi">200</span> 
<span class="n">NUM_TEST_BATCHES</span> <span class="o">=</span> <span class="mi">75</span> <span class="n">AUTOTUNE</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span> <span class="k">def</span> <span class="nf">prepare_wiki_data</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">num_batches</span><span class="p">):</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">map</span><span class="p">(</span> <span class="k">lambda</span> <span class="n">z</span><span class="p">:</span> <span class="p">((</span><span class="n">z</span><span class="p">[</span><span class="s2">"Sentence1"</span><span class="p">],</span> <span class="n">z</span><span class="p">[</span><span class="s2">"Sentence2"</span><span class="p">],</span> <span class="n">z</span><span class="p">[</span><span class="s2">"Sentence3"</span><span class="p">]),</span> <span class="mi">0</span><span class="p">)</span> <span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="mi">6</span><span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="n">num_batches</span><span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">AUTOTUNE</span><span class="p">)</span> <span class="k">return</span> <span class="n">dataset</span> <span class="n">wiki_train</span> <span class="o">=</span> <span 
class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">experimental</span><span class="o">.</span><span class="n">make_csv_dataset</span><span class="p">(</span> <span class="s2">"wikipedia-sections-triplets/train.csv"</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_epochs</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="p">)</span> <span class="n">wiki_test</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">experimental</span><span class="o">.</span><span class="n">make_csv_dataset</span><span class="p">(</span> <span class="s2">"wikipedia-sections-triplets/test.csv"</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_epochs</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="p">)</span> <span class="n">wiki_train</span> <span class="o">=</span> <span class="n">prepare_wiki_data</span><span class="p">(</span><span class="n">wiki_train</span><span class="p">,</span> <span class="n">NUM_TRAIN_BATCHES</span><span class="p">)</span> <span class="n">wiki_test</span> <span class="o">=</span> <span class="n">prepare_wiki_data</span><span class="p">(</span><span class="n">wiki_test</span><span class="p">,</span> <span class="n">NUM_TEST_BATCHES</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Archive: wikipedia-sections-triplets.zip inflating: wikipedia-sections-triplets/validation.csv inflating: wikipedia-sections-triplets/Readme.txt inflating: wikipedia-sections-triplets/test.csv inflating: wikipedia-sections-triplets/train.csv 
</code></pre></div> </div> <h4 id="build-the-encoder-model">Build the encoder model</h4> <p>For this encoder model, we will use RoBERTa with mean pooling and we will not normalize the output embeddings. The encoder model consists of:</p> <ul> <li>A preprocessor layer to tokenize and generate padding masks for the sentences.</li> <li>A backbone model that will generate the contextual representation of each token in the sentence.</li> <li>A mean pooling layer to produce the embeddings.</li> </ul> <div class="codehilite"><pre><span></span><code><span class="n">preprocessor</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">RobertaPreprocessor</span><span class="o">.</span><span class="n">from_preset</span><span class="p">(</span><span class="s2">"roberta_base_en"</span><span class="p">)</span> <span class="n">backbone</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">RobertaBackbone</span><span class="o">.</span><span class="n">from_preset</span><span class="p">(</span><span class="s2">"roberta_base_en"</span><span class="p">)</span> <span class="nb">input</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"string"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"sentence"</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">preprocessor</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span> <span class="n">h</span> <span class="o">=</span> <span 
class="n">backbone</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">embedding</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">GlobalAveragePooling1D</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"pooling_layer"</span><span class="p">)(</span> <span class="n">h</span><span class="p">,</span> <span class="n">x</span><span class="p">[</span><span class="s2">"padding_mask"</span><span class="p">]</span> <span class="p">)</span> <span class="n">roberta_encoder</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="nb">input</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">embedding</span><span class="p">)</span> <span class="n">roberta_encoder</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> </code></pre></div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "functional_3"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃<span style="font-weight: bold"> Connected to </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩ │ sentence │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: 
#00af00">1</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ - │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">InputLayer</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ roberta_preprocess… │ [(<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>), │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ sentence[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>] │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">RobertaPreprocess…</span> │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>)] │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ roberta_backbone_1 │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>, <span style="color: #00af00; text-decoration-color: #00af00">768</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">124,05…</span> │ roberta_preprocesso… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">RobertaBackbone</span>) │ │ │ roberta_preprocesso… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ pooling_layer │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">768</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ roberta_backbone_1[<span style="color: #00af00; text-decoration-color: #00af00">…</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">GlobalAveragePool…</span> │ │ │ roberta_preprocesso… │ 
└─────────────────────┴───────────────────┴─────────┴──────────────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">124,052,736</span> (473.22 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">124,052,736</span> (473.22 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <h4 id="build-the-siamese-network-with-the-triplet-objective-function">Build the Siamese network with the triplet objective function</h4> <p>For the Siamese network with the triplet objective function, we will build the model with an encoder, and we will pass the three sentences through that encoder. 
We will get an embedding for each sentence, and we will calculate the <code>positive_dist</code> and <code>negative_dist</code> that will be passed to the loss function described below.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">TripletSiamese</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">encoder</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="n">anchor</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"string"</span><span class="p">)</span> <span class="n">positive</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"string"</span><span class="p">)</span> <span class="n">negative</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"string"</span><span class="p">)</span> <span class="n">ea</span> <span class="o">=</span> <span class="n">encoder</span><span 
class="p">(</span><span class="n">anchor</span><span class="p">)</span> <span class="n">ep</span> <span class="o">=</span> <span class="n">encoder</span><span class="p">(</span><span class="n">positive</span><span class="p">)</span> <span class="n">en</span> <span class="o">=</span> <span class="n">encoder</span><span class="p">(</span><span class="n">negative</span><span class="p">)</span> <span class="n">positive_dist</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">ea</span> <span class="o">-</span> <span class="n">ep</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="n">negative_dist</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">ea</span> <span class="o">-</span> <span class="n">en</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="n">positive_dist</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">positive_dist</span><span class="p">)</span> <span class="n">negative_dist</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span 
class="n">sqrt</span><span class="p">(</span><span class="n">negative_dist</span><span class="p">)</span> <span class="n">output</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">positive_dist</span><span class="p">,</span> <span class="n">negative_dist</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="p">[</span><span class="n">anchor</span><span class="p">,</span> <span class="n">positive</span><span class="p">,</span> <span class="n">negative</span><span class="p">],</span> <span class="n">outputs</span><span class="o">=</span><span class="n">output</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">encoder</span> <span class="k">def</span> <span class="nf">get_encoder</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> </code></pre></div> <p>We will use a custom loss function for the triplet objective. 
The loss function receives the distance between the <em>anchor</em> and the <em>positive</em> embeddings, <code>positive_dist</code>, and the distance between the <em>anchor</em> and the <em>negative</em> embeddings, <code>negative_dist</code>, stacked together in <code>y_pred</code>.</p> <p>We will use <code>positive_dist</code> and <code>negative_dist</code> to compute the loss such that <code>negative_dist</code> exceeds <code>positive_dist</code> by at least a specific margin. Mathematically, we will minimize this loss function: <code>max(positive_dist - negative_dist + margin, 0)</code>.</p> <p>This loss function does not use <code>y_true</code>. Note that we set the labels in the dataset to zero, but they are not used.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">TripletLoss</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">Loss</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">margin</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">margin</span> <span class="o">=</span> <span class="n">margin</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span> <span 
class="n">positive_dist</span><span class="p">,</span> <span class="n">negative_dist</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">unstack</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">losses</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">positive_dist</span> <span class="o">-</span> <span class="n">negative_dist</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">margin</span><span class="p">)</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">losses</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> </code></pre></div> <h4 id="fit-the-model">Fit the model</h4> <p>For training, we will use the custom <code>TripletLoss()</code> loss function and the <code>Adam()</code> optimizer with a learning rate of 2e-5.</p> <div class="codehilite"><pre><span></span><code><span class="n">roberta_triplet_siamese</span> <span class="o">=</span> <span class="n">TripletSiamese</span><span class="p">(</span><span class="n">roberta_encoder</span><span class="p">)</span> <span class="n">roberta_triplet_siamese</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">loss</span><span class="o">=</span><span class="n">TripletLoss</span><span class="p">(),</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span 
class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="mf">2e-5</span><span class="p">),</span> <span class="n">jit_compile</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="p">)</span> <span class="n">roberta_triplet_siamese</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">wiki_train</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">wiki_test</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 200/200 ━━━━━━━━━━━━━━━━━━━━ 128s 467ms/step - loss: 0.7822 - val_loss: 0.7126 &lt;keras.src.callbacks.history.History at 0x7f5c3636c580&gt; </code></pre></div> </div> <p>Let's try this model in a clustering example. Here are 6 questions: the first 3 are about learning English, and the last 3 are about working online. 
Let's see if the embeddings produced by our encoder will cluster them correctly.</p> <div class="codehilite"><pre><span></span><code><span class="n">questions</span> <span class="o">=</span> <span class="p">[</span> <span class="s2">"What should I do to improve my English writting?"</span><span class="p">,</span> <span class="s2">"How to be good at speaking English?"</span><span class="p">,</span> <span class="s2">"How can I improve my English?"</span><span class="p">,</span> <span class="s2">"How to earn money online?"</span><span class="p">,</span> <span class="s2">"How do I earn money online?"</span><span class="p">,</span> <span class="s2">"How to work and earn money through internet?"</span><span class="p">,</span> <span class="p">]</span> <span class="n">encoder</span> <span class="o">=</span> <span class="n">roberta_triplet_siamese</span><span class="o">.</span><span class="n">get_encoder</span><span class="p">()</span> <span class="n">embeddings</span> <span class="o">=</span> <span class="n">encoder</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="n">questions</span><span class="p">))</span> <span class="n">kmeans</span> <span class="o">=</span> <span class="n">cluster</span><span class="o">.</span><span class="n">KMeans</span><span class="p">(</span><span class="n">n_clusters</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">n_init</span><span class="o">=</span><span class="s2">"auto"</span><span class="p">)</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">embeddings</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">label</span> <span class="ow">in</span> <span 
class="nb">enumerate</span><span class="p">(</span><span class="n">kmeans</span><span class="o">.</span><span class="n">labels_</span><span class="p">):</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"sentence (</span><span class="si">{</span><span class="n">questions</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="si">}</span><span class="s2">) belongs to cluster </span><span class="si">{</span><span class="n">label</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>sentence (What should I do to improve my English writting?) belongs to cluster 1 sentence (How to be good at speaking English?) belongs to cluster 1 sentence (How can I improve my English?) belongs to cluster 1 sentence (How to earn money online?) belongs to cluster 0 sentence (How do I earn money online?) belongs to cluster 0 sentence (How to work and earn money through internet?) 
belongs to cluster 0 </code></pre></div> </div> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#sentence-embeddings-using-siamese-robertanetworks'>Sentence embeddings using Siamese RoBERTa-networks</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#finetune-the-model-using-siamese-networks'>Fine-tune the model using siamese networks</a> </div> <div class='k-outline-depth-3'> <a href='#finetune-using-the-regression-objective-function'>Fine-tune using the regression objective function</a> </div> <div class='k-outline-depth-3'> <a href='#finetune-using-the-triplet-objective-function'>Fine-tune Using the triplet Objective Function</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>