# Semantic Similarity with KerasHub
autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/nlp/'>Natural Language Processing</a> / Semantic Similarity with KerasHub </div> <div class='k-content'> <h1 id="semantic-similarity-with-kerashub">Semantic Similarity with KerasHub</h1> <p><strong>Author:</strong> <a href="https://github.com/shivance/">Anshuman Mishra</a><br> <strong>Date created:</strong> 2023/02/25<br> <strong>Last modified:</strong> 2023/02/25<br> <strong>Description:</strong> Use pretrained models from KerasHub for the Semantic Similarity Task.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/nlp/ipynb/semantic_similarity_with_keras_hub.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/nlp/semantic_similarity_with_keras_hub.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>Semantic similarity refers to the task of determining the degree of similarity between two sentences in terms of their meaning. We already saw in <a href="https://keras.io/examples/nlp/semantic_similarity_with_bert/">this</a> example how to use SNLI (Stanford Natural Language Inference) corpus to predict sentence semantic similarity with the HuggingFace Transformers library. In this tutorial we will learn how to use <a href="https://keras.io/keras_hub/">KerasHub</a>, an extension of the core Keras API, for the same task. Furthermore, we will discover how KerasHub effectively reduces boilerplate code and simplifies the process of building and utilizing models. 
```python
!pip install -q --upgrade keras-hub
!pip install -q --upgrade keras  # Upgrade to Keras 3.
```

```python
import numpy as np
import tensorflow as tf
import keras
import keras_hub
import tensorflow_datasets as tfds
```

To load the SNLI dataset, we use the tensorflow-datasets library, which contains over 550,000 samples in total. However, to ensure that this example runs quickly, we use only 20% of the training samples.

---

## Overview of SNLI Dataset

Every sample in the dataset contains three components: `hypothesis`, `premise`, and `label`. The premise represents the original caption provided to the author of the pair, while the hypothesis refers to the hypothesis caption created by the author of the pair. The label is assigned by annotators to indicate the similarity between the two sentences.

The dataset contains three possible similarity label values: Contradiction, Entailment, and Neutral. Contradiction represents completely dissimilar sentences, while Entailment denotes sentences with similar meaning. Lastly, Neutral refers to sentences where no clear similarity or dissimilarity can be established between them.
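The integer values behind these labels come from the dataset metadata rather than from this guide. A quick way to inspect the index-to-name mapping, using the same `tensorflow_datasets` API as below:

```python
# Load a split together with its metadata and print the label names.
# The printed order defines which integer maps to which class.
_, snli_info = tfds.load("snli", split="validation", with_info=True)
print(snli_info.features["label"].names)
```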
```python
snli_train = tfds.load("snli", split="train[:20%]")
snli_val = tfds.load("snli", split="validation")
snli_test = tfds.load("snli", split="test")

# Here's an example of what our samples look like; we take a batch of
# four examples from the test split:
sample = snli_test.batch(4).take(1).get_single_element()
sample
```

```
{'hypothesis': <tf.Tensor: shape=(4,), dtype=string, numpy=
 array([b'A girl is entertaining on stage',
        b'A group of people posing in front of a body of water.',
        b"The group of people aren't inide of the building.",
        b'The people are taking a carriage ride.'], dtype=object)>,
 'label': <tf.Tensor: shape=(4,), dtype=int64, numpy=array([0, 0, 0, 0])>,
 'premise': <tf.Tensor: shape=(4,), dtype=string, numpy=
 array([b'A girl in a blue leotard hula hoops on a stage with balloon shapes in the background.',
        b'A group of people taking pictures on a walkway in front of a large body of water.',
        b'Many people standing outside of a place talking to each other in front of a building that has a sign that says "HI-POINTE."',
        b'Three people are riding a carriage pulled by four horses.'], dtype=object)>}
```

### Preprocessing

In our dataset, we have identified that some samples have missing or incorrectly labeled data, which is denoted by a value of -1. To ensure the accuracy and reliability of our model, we simply filter out these samples from our dataset.

```python
def filter_labels(sample):
    return sample["label"] >= 0
```
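If you are curious how many examples this filter actually drops, the dataset can be walked once to count them. A rough sketch (we use the smaller validation split, since walking the full training split is slow):

```python
# Count the validation samples carrying the -1 "missing label" marker.
num_unlabeled = int(
    snli_val.filter(lambda s: s["label"] < 0).reduce(0, lambda count, _: count + 1)
)
print(num_unlabeled)
```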
Here's a utility function that splits the example into an `(x, y)` tuple that is suitable for `model.fit()`. By default, `keras_hub.models.BertClassifier` will tokenize and pack together raw strings using a `"[SEP]"` token during training. Therefore, this label splitting is all the data preparation that we need to perform.

```python
def split_labels(sample):
    x = (sample["hypothesis"], sample["premise"])
    y = sample["label"]
    return x, y


train_ds = (
    snli_train.filter(filter_labels)
    .map(split_labels, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(16)
)
val_ds = (
    snli_val.filter(filter_labels)
    .map(split_labels, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(16)
)
test_ds = (
    snli_test.filter(filter_labels)
    .map(split_labels, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(16)
)
```
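As a quick sanity check of the pipeline, we can pull a single `(x, y)` element out of it before batching. A sketch (the exact strings depend on the example drawn):

```python
# Re-apply the filter/map steps without batching so one example is easy to read.
example_x, example_y = next(
    iter(snli_val.filter(filter_labels).map(split_labels).take(1))
)
print(example_x)  # A (hypothesis, premise) pair of scalar string tensors.
print(example_y)  # An integer label in {0, 1, 2}.
```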
---

## Establishing baseline with BERT

We use the BERT model from KerasHub to establish a baseline for our semantic similarity task. The `keras_hub.models.BertClassifier` class attaches a classification head to the BERT Backbone, mapping the backbone outputs to a logit output suitable for a classification task. This significantly reduces the need for custom code.

KerasHub models have built-in tokenization capabilities that handle tokenization by default based on the selected model. However, users can also use custom preprocessing techniques as per their specific needs. If we pass a tuple as input, the model will tokenize all the strings and concatenate them with a `"[SEP]"` separator.

We instantiate the model with pretrained weights using the `from_preset()` method, which also lets us supply our own preprocessor if needed. For the SNLI dataset, we set `num_classes` to 3.

```python
bert_classifier = keras_hub.models.BertClassifier.from_preset(
    "bert_tiny_en_uncased", num_classes=3
)
```

Please note that the BERT Tiny model has only 4,386,307 trainable parameters.

KerasHub task models come with compilation defaults. We can now train the model we just instantiated by calling the `fit()` method.

```python
bert_classifier.fit(train_ds, validation_data=val_ds, epochs=1)
```

```
6867/6867 ━━━━━━━━━━━━━━━━━━━━ 61s 8ms/step - loss: 0.8732 - sparse_categorical_accuracy: 0.5864 - val_loss: 0.5900 - val_sparse_categorical_accuracy: 0.7602

<keras.src.callbacks.history.History at 0x7f4660171fc0>
```

Our BERT classifier achieved an accuracy of around 76% on the validation split. Now, let's evaluate its performance on the test split.

### Evaluate the performance of the trained model on test data

```python
bert_classifier.evaluate(test_ds)
```

```
614/614 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.5815 - sparse_categorical_accuracy: 0.7628

[0.5895748734474182, 0.7618078589439392]
```

Our baseline BERT model achieved a similar accuracy of around 76% on the test split.
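Because the preset bundles its own preprocessor, the trained classifier can be probed directly on raw strings. A small sketch with an invented sentence pair:

```python
# The classifier consumes raw (hypothesis, premise) string pairs; the
# bundled preprocessor handles tokenization and "[SEP]" packing internally.
demo_hypothesis = tf.constant(["Some men are playing a sport."])
demo_premise = tf.constant(["A soccer game with multiple males playing."])
demo_logits = bert_classifier.predict((demo_hypothesis, demo_premise))
print(demo_logits.shape)  # (1, 3): one example, three class logits.
```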
Now, let's try to improve its performance by recompiling the model with a slightly higher learning rate.

```python
bert_classifier = keras_hub.models.BertClassifier.from_preset(
    "bert_tiny_en_uncased", num_classes=3
)
bert_classifier.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.Adam(5e-5),
    metrics=["accuracy"],
)

bert_classifier.fit(train_ds, validation_data=val_ds, epochs=1)
bert_classifier.evaluate(test_ds)
```

```
6867/6867 ━━━━━━━━━━━━━━━━━━━━ 59s 8ms/step - accuracy: 0.6007 - loss: 0.8636 - val_accuracy: 0.7648 - val_loss: 0.5800
614/614 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - accuracy: 0.7700 - loss: 0.5692

[0.578984260559082, 0.7686278820037842]
```

Just tweaking the learning rate alone was not enough to boost performance, which stayed right around 76%.
Let's try again, but this time with [`keras.optimizers.AdamW`](/api/optimizers/adamw#adamw-class) and a learning rate schedule.

```python
class TriangularSchedule(keras.optimizers.schedules.LearningRateSchedule):
    """Linear ramp up for `warmup` steps, then linear decay to zero at `total` steps."""

    def __init__(self, rate, warmup, total):
        self.rate = rate
        self.warmup = warmup
        self.total = total

    def get_config(self):
        config = {"rate": self.rate, "warmup": self.warmup, "total": self.total}
        return config

    def __call__(self, step):
        step = keras.ops.cast(step, dtype="float32")
        rate = keras.ops.cast(self.rate, dtype="float32")
        warmup = keras.ops.cast(self.warmup, dtype="float32")
        total = keras.ops.cast(self.total, dtype="float32")

        warmup_rate = rate * step / self.warmup
        cooldown_rate = rate * (total - step) / (total - warmup)
        triangular_rate = keras.ops.minimum(warmup_rate, cooldown_rate)
        return keras.ops.maximum(triangular_rate, 0.0)


bert_classifier = keras_hub.models.BertClassifier.from_preset(
    "bert_tiny_en_uncased", num_classes=3
)

# Get the total count of training batches.
# This requires walking the dataset to filter all -1 labels.
epochs = 3
total_steps = sum(1 for _ in train_ds.as_numpy_iterator()) * epochs
warmup_steps = int(total_steps * 0.2)

bert_classifier.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=keras.optimizers.AdamW(
        TriangularSchedule(1e-4, warmup_steps, total_steps)
    ),
    metrics=["accuracy"],
)

bert_classifier.fit(train_ds, validation_data=val_ds, epochs=epochs)
```

```
Epoch 1/3
6867/6867 ━━━━━━━━━━━━━━━━━━━━ 59s 8ms/step - accuracy: 0.5457 - loss: 0.9317 - val_accuracy: 0.7633 - val_loss: 0.5825
Epoch 2/3
6867/6867 ━━━━━━━━━━━━━━━━━━━━ 55s 8ms/step - accuracy: 0.7291 - loss: 0.6515 - val_accuracy: 0.7809 - val_loss: 0.5399
Epoch 3/3
6867/6867 ━━━━━━━━━━━━━━━━━━━━ 55s 8ms/step - accuracy: 0.7708 - loss: 0.5695 - val_accuracy: 0.7918 - val_loss: 0.5214

<keras.src.callbacks.history.History at 0x7f45645b3370>
```

Success! With the learning rate scheduler and the `AdamW` optimizer, our validation accuracy improved to around 79%.

Now, let's evaluate our final model on the test set and see how it performs.

```python
bert_classifier.evaluate(test_ds)
```

```
614/614 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - accuracy: 0.7956 - loss: 0.5128

[0.5245093703269958, 0.7890879511833191]
```

Our Tiny BERT model achieved an accuracy of approximately 79% on the test set with the use of a learning rate scheduler. This is a significant improvement over our previous results.
Fine-tuning a pretrained BERT model can be a powerful tool in natural language processing tasks, and even a small model like Tiny BERT can achieve impressive results.

Let's save our model for now and move on to learning how to perform inference with it.

---

## Save and Reload the model

```python
bert_classifier.save("bert_classifier.keras")
restored_model = keras.models.load_model("bert_classifier.keras")
restored_model.evaluate(test_ds)
```

```
614/614 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 0.5128 - sparse_categorical_accuracy: 0.7956

[0.5245093703269958, 0.7890879511833191]
```

---

## Performing inference with the model

Let's see how to perform inference with KerasHub models.

```python
# Convert to Hypothesis-Premise pair, for forward pass through model
sample = (sample["hypothesis"], sample["premise"])
sample
```

```
(<tf.Tensor: shape=(4,), dtype=string, numpy=
 array([b'A girl is entertaining on stage',
        b'A group of people posing in front of a body of water.',
        b"The group of people aren't inide of the building.",
        b'The people are taking a carriage ride.'], dtype=object)>,
 <tf.Tensor: shape=(4,), dtype=string, numpy=
 array([b'A girl in a blue leotard hula hoops on a stage with balloon shapes in the background.',
        b'A group of people taking pictures on a walkway in front of a large body of water.',
        b'Many people standing outside of a place talking to each other in front of a building that has a sign that says "HI-POINTE."',
        b'Three people are riding a carriage pulled by four horses.'], dtype=object)>)
```

The default preprocessor in KerasHub models handles input tokenization automatically, so we don't need to perform tokenization explicitly.

```python
predictions = bert_classifier.predict(sample)


def softmax(x):
    # Normalize over the class axis so each row sums to one.
    return np.exp(x) / np.exp(x).sum(axis=-1, keepdims=True)


# Convert the logits to per-class probabilities.
predictions = softmax(predictions)
```

```
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 711ms/step
```

---

## Improving accuracy with RoBERTa

Now that we have established a baseline, we can attempt to improve our results by experimenting with different models. Thanks to KerasHub, fine-tuning a RoBERTa checkpoint on the same dataset is easy with just a few lines of code.

```python
# Initializing a RoBERTa classifier from a preset
roberta_classifier = keras_hub.models.RobertaClassifier.from_preset(
    "roberta_base_en", num_classes=3
)

roberta_classifier.fit(train_ds, validation_data=val_ds, epochs=1)

roberta_classifier.evaluate(test_ds)
```

```
6867/6867 ━━━━━━━━━━━━━━━━━━━━ 2049s 297ms/step - loss: 0.5509 - sparse_categorical_accuracy: 0.7740 - val_loss: 0.3292 - val_sparse_categorical_accuracy: 0.8789
614/614 ━━━━━━━━━━━━━━━━━━━━ 56s 88ms/step - loss: 0.3307 - sparse_categorical_accuracy: 0.8784

[0.33771008253097534, 0.874796450138092]
```

The RoBERTa base model has significantly more trainable parameters than the BERT Tiny model, with almost 30 times as many at 124,645,635 parameters. As a result, it took approximately 1.5 hours to train on a P100 GPU. However, the performance improvement was substantial, with accuracy increasing to about 88% on both the validation and test splits. With RoBERTa, we were able to fit a maximum batch size of 16 on our P100 GPU.
Despite using a different model, the steps to perform inference with RoBERTa are the same as with BERT!

```python
predictions = roberta_classifier.predict(sample)
print(tf.math.argmax(predictions, axis=1).numpy())
```

```
1/1 ━━━━━━━━━━━━━━━━━━━━ 4s 4s/step
[0 0 0 0]
```

We hope this tutorial has been helpful in demonstrating the ease and effectiveness of using KerasHub and BERT for semantic similarity tasks.

Throughout this tutorial, we demonstrated how to use a pretrained BERT model to establish a baseline and improve performance by training a larger RoBERTa model using just a few lines of code.

The KerasHub toolbox provides a range of modular building blocks for preprocessing text, including pretrained state-of-the-art models and low-level Transformer Encoder layers. We believe that this makes experimenting with natural language solutions more accessible and efficient.