<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/nlp/multimodal_entailment/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Multimodal entailment"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Multimodal entailment"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Multimodal entailment</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> 
<script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink active" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_from_scratch/">Text classification from scratch</a> <a class="nav-sublink2" href="/examples/nlp/active_learning_review_classification/">Review Classification using Active Learning</a> <a class="nav-sublink2" href="/examples/nlp/fnet_classification_with_keras_hub/">Text Classification using FNet</a> <a class="nav-sublink2" href="/examples/nlp/multi_label_classification/">Large-scale multi-label text classification</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_with_transformer/">Text classification with Transformer</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_with_switch_transformer/">Text classification with Switch Transformer</a> <a class="nav-sublink2" href="/examples/nlp/tweet-classification-using-tfdf/">Text classification using Decision Forests and pretrained embeddings</a> <a class="nav-sublink2" 
href="/examples/nlp/pretrained_word_embeddings/">Using pre-trained word embeddings</a> <a class="nav-sublink2" href="/examples/nlp/bidirectional_lstm_imdb/">Bidirectional LSTM on IMDB</a> <a class="nav-sublink2" href="/examples/nlp/data_parallel_training_with_keras_hub/">Data Parallel Training with KerasHub and tf.distribute</a> <a class="nav-sublink2" href="/examples/nlp/neural_machine_translation_with_keras_hub/">English-to-Spanish translation with KerasHub</a> <a class="nav-sublink2" href="/examples/nlp/neural_machine_translation_with_transformer/">English-to-Spanish translation with a sequence-to-sequence Transformer</a> <a class="nav-sublink2" href="/examples/nlp/lstm_seq2seq/">Character-level recurrent sequence-to-sequence model</a> <a class="nav-sublink2 active" href="/examples/nlp/multimodal_entailment/">Multimodal entailment</a> <a class="nav-sublink2" href="/examples/nlp/ner_transformers/">Named Entity Recognition using Transformers</a> <a class="nav-sublink2" href="/examples/nlp/text_extraction_with_bert/">Text Extraction with BERT</a> <a class="nav-sublink2" href="/examples/nlp/addition_rnn/">Sequence to sequence learning for performing number addition</a> <a class="nav-sublink2" href="/examples/nlp/semantic_similarity_with_keras_hub/">Semantic Similarity with KerasHub</a> <a class="nav-sublink2" href="/examples/nlp/semantic_similarity_with_bert/">Semantic Similarity with BERT</a> <a class="nav-sublink2" href="/examples/nlp/sentence_embeddings_with_sbert/">Sentence embeddings using Siamese RoBERTa-networks</a> <a class="nav-sublink2" href="/examples/nlp/masked_language_modeling/">End-to-end Masked Language Modeling with BERT</a> <a class="nav-sublink2" href="/examples/nlp/abstractive_summarization_with_bart/">Abstractive Text Summarization with BART</a> <a class="nav-sublink2" href="/examples/nlp/pretraining_BERT/">Pretraining BERT with Hugging Face Transformers</a> <a class="nav-sublink2" 
href="/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/">Parameter-efficient fine-tuning of GPT-2 with LoRA</a> <a class="nav-sublink2" href="/examples/nlp/multiple_choice_task_with_transfer_learning/">MultipleChoice Task with Transfer Learning</a> <a class="nav-sublink2" href="/examples/nlp/question_answering/">Question Answering with Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/nlp/t5_hf_summarization/">Abstractive Summarization with Hugging Face Transformers</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; 
document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = 
'/search.html?query=' + query; return false } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/nlp/'>Natural Language Processing</a> / Multimodal entailment </div> <div class='k-content'> <h1 id="multimodal-entailment">Multimodal entailment</h1> <p><strong>Author:</strong> <a href="https://twitter.com/RisingSayak">Sayak Paul</a><br> <strong>Date created:</strong> 2021/08/08<br> <strong>Last modified:</strong> 2021/08/15<br> <strong>Description:</strong> Training a multimodal model for predicting entailment.</p> <div class='example_version_banner keras_2'>ⓘ This example uses Keras 2</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/nlp/ipynb/multimodal_entailment.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/nlp/multimodal_entailment.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>In this example, we will build and train a model for predicting multimodal entailment. We will be using the <a href="https://github.com/google-research-datasets/recognizing-multimodal-entailment">multimodal entailment dataset</a> recently introduced by Google Research.</p> <h3 id="what-is-multimodal-entailment">What is multimodal entailment?</h3> <p>On social media platforms, to audit and moderate content, we may want to find answers to the following questions in near real-time:</p> <ul> <li>Does a given piece of information contradict the other?</li> <li>Does a given piece of information imply the other?</li> </ul> <p>In NLP, this task is called analyzing <em>textual entailment</em>. 
However, that's only when the information comes from text content. In practice, it's often the case that the available information comes not just from text content, but from a multimodal combination of text, images, audio, video, etc. <em>Multimodal entailment</em> is simply the extension of textual entailment to a variety of new input modalities.</p> <h3 id="requirements">Requirements</h3> <p>This example requires TensorFlow 2.5 or higher. In addition, TensorFlow Hub and TensorFlow Text are required for the BERT model (<a href="https://arxiv.org/abs/1810.04805">Devlin et al.</a>). These libraries can be installed using the following command:</p> <div class="codehilite"><pre><span></span><code><span class="err">!</span><span class="n">pip</span> <span class="n">install</span> <span class="o">-</span><span class="n">q</span> <span class="n">tensorflow_text</span> </code></pre></div> <hr /> <h2 id="imports">Imports</h2> <div class="codehilite"><pre><span></span><code><span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">train_test_split</span> <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span> <span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">os</span> <span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span> <span class="kn">import</span> <span class="nn">tensorflow_hub</span> <span class="k">as</span> <span class="nn">hub</span> <span class="kn">import</span> <span class="nn">tensorflow_text</span> <span class="k">as</span> <span class="nn">text</span> <span class="kn">from</span> <span class="nn">tensorflow</span> <span 
class="kn">import</span> <span class="n">keras</span> </code></pre></div> <hr /> <h2 id="define-a-label-map">Define a label map</h2> <div class="codehilite"><pre><span></span><code><span class="n">label_map</span> <span class="o">=</span> <span class="p">{</span><span class="s2">&quot;Contradictory&quot;</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">&quot;Implies&quot;</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">&quot;NoEntailment&quot;</span><span class="p">:</span> <span class="mi">2</span><span class="p">}</span> </code></pre></div> <hr /> <h2 id="collect-the-dataset">Collect the dataset</h2> <p>The original dataset is available <a href="https://github.com/google-research-datasets/recognizing-multimodal-entailment">here</a>. It comes with URLs of images which are hosted on Twitter's photo storage system called the <a href="https://blog.twitter.com/engineering/en_us/a/2012/blobstore-twitter-s-in-house-photo-storage-system">Photo Blob Storage (PBS for short)</a>. We will be working with the downloaded images along with additional data that comes with the original dataset. 
Thanks to <a href="https://de.linkedin.com/in/nilabhraroychowdhury">Nilabhra Roy Chowdhury</a>, who prepared the image data.</p> <div class="codehilite"><pre><span></span><code><span class="n">image_base_path</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">get_file</span><span class="p">(</span> <span class="s2">&quot;tweet_images&quot;</span><span class="p">,</span> <span class="s2">&quot;https://github.com/sayakpaul/Multimodal-Entailment-Baseline/releases/download/v1.0.0/tweet_images.tar.gz&quot;</span><span class="p">,</span> <span class="n">untar</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="read-the-dataset-and-apply-basic-preprocessing">Read the dataset and apply basic preprocessing</h2> <div class="codehilite"><pre><span></span><code><span class="n">df</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_csv</span><span class="p">(</span> <span class="s2">&quot;https://github.com/sayakpaul/Multimodal-Entailment-Baseline/raw/main/csvs/tweets.csv&quot;</span> <span class="p">)</span> <span class="n">df</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="mi">10</span><span class="p">)</span> </code></pre></div> <div style="overflow-x: scroll; width: 100%;"> <style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; } .dataframe tbody tr th { vertical-align: top; } .dataframe thead th { text-align: right; } </style> <table border="1" class="dataframe"> <thead> <tr style="text-align: right;"> <th></th> <th>id_1</th> <th>text_1</th> <th>image_1</th> <th>id_2</th> <th>text_2</th> <th>image_2</th> <th>label</th> </tr> </thead> <tbody> <tr> 
<th>291</th> <td>1330800194863190016</td> <td>#KLM1167 (B738): #AMS (Amsterdam) to #HEL (Van...</td> <td>http://pbs.twimg.com/media/EnfzuZAW4AE236p.png</td> <td>1378695438480588802</td> <td>#CKK205 (B77L): #PVG (Shanghai) to #AMS (Amste...</td> <td>http://pbs.twimg.com/media/EyIcMexXEAE6gia.png</td> <td>NoEntailment</td> </tr> <tr> <th>37</th> <td>1366581728312057856</td> <td>Friends, interested all go to have a look!\n@j...</td> <td>http://pbs.twimg.com/media/EvcS1v4UcAEEXPO.jpg</td> <td>1373810535066570759</td> <td>Friends, interested all go to have a look!\n@f...</td> <td>http://pbs.twimg.com/media/ExDBZqwVIAQ4LWk.jpg</td> <td>Contradictory</td> </tr> <tr> <th>315</th> <td>1352551603258052608</td> <td>#WINk Drops I have earned today🚀\n\nToday:1/22...</td> <td>http://pbs.twimg.com/media/EsTdcLLVcAIiFKT.jpg</td> <td>1354636016234098688</td> <td>#WINk Drops I have earned today☀\n\nToday:1/28...</td> <td>http://pbs.twimg.com/media/EsyhK-qU0AgfMAH.jpg</td> <td>NoEntailment</td> </tr> <tr> <th>761</th> <td>1379795999493853189</td> <td>#buythedip Ready to FLY even HIGHER #pennysto...</td> <td>http://pbs.twimg.com/media/EyYFJCzWgAMfTrT.jpg</td> <td>1380190250144792576</td> <td>#buythedip Ready to FLY even HIGHER #pennysto...</td> <td>http://pbs.twimg.com/media/Eydrt0ZXAAMmbfv.jpg</td> <td>NoEntailment</td> </tr> <tr> <th>146</th> <td>1340185132293099523</td> <td>I know sometimes I am weird to you.\n\nBecause...</td> <td>http://pbs.twimg.com/media/EplLRriWwAAJ2AE.jpg</td> <td>1359755419883814913</td> <td>I put my sword down and get on my knees to swe...</td> <td>http://pbs.twimg.com/media/Et7SWWeWYAICK-c.jpg</td> <td>NoEntailment</td> </tr> <tr> <th>1351</th> <td>1381256604926967813</td> <td>Finally completed the skin rendering. Will sta...</td> <td>http://pbs.twimg.com/media/Eys1j7NVIAgF-YF.jpg</td> <td>1381630932092784641</td> <td>Hair rendering. 
Will finish the hair by tomorr...</td> <td>http://pbs.twimg.com/media/EyyKAoaUUAElm-e.jpg</td> <td>NoEntailment</td> </tr> <tr> <th>368</th> <td>1371883298805403649</td> <td>📉 $LINK Number of Receiving Addresses (7d MA) ...</td> <td>http://pbs.twimg.com/media/EwnoltOWEAAS4mG.jpg</td> <td>1373216720974979072</td> <td>📉 $LINK Number of Receiving Addresses (7d MA) ...</td> <td>http://pbs.twimg.com/media/Ew6lVGYXEAE6Ugi.jpg</td> <td>NoEntailment</td> </tr> <tr> <th>1112</th> <td>1377679115159887873</td> <td>April is National Distracted Driving Awareness...</td> <td>http://pbs.twimg.com/media/Ex5_u7UVIAARjQ2.jpg</td> <td>1379075258448281608</td> <td>April is Distracted Driving Awareness Month. ...</td> <td>http://pbs.twimg.com/media/EyN1YjpWUAMc5ak.jpg</td> <td>NoEntailment</td> </tr> <tr> <th>264</th> <td>1330727515741167619</td> <td>♥️Verse Of The Day♥️\n.\n#VerseOfTheDay #Quran...</td> <td>http://pbs.twimg.com/media/EnexnydXIAYuI11.jpg</td> <td>1332623263495819264</td> <td>♥️Verse Of The Day♥️\n.\n#VerseOfTheDay #Quran...</td> <td>http://pbs.twimg.com/media/En5ty1VXUAATALP.jpg</td> <td>NoEntailment</td> </tr> <tr> <th>865</th> <td>1377784616275296261</td> <td>No white picket fence can keep us in. 
#TBT 200...</td> <td>http://pbs.twimg.com/media/Ex7fzouWQAITAq8.jpg</td> <td>1380175915804672012</td> <td>Sometimes you just need to change your altitud...</td> <td>http://pbs.twimg.com/media/EydernQXIAk2g5v.jpg</td> <td>NoEntailment</td> </tr> </tbody> </table> </div> <p>The columns we are interested in are the following:</p> <ul> <li><code>text_1</code></li> <li><code>image_1</code></li> <li><code>text_2</code></li> <li><code>image_2</code></li> <li><code>label</code></li> </ul> <p>The entailment task is formulated as follows:</p> <p><strong><em>Given the pairs of (<code>text_1</code>, <code>image_1</code>) and (<code>text_2</code>, <code>image_2</code>), do they entail (or not entail or contradict) each other?</em></strong></p> <p>The images have already been downloaded: each <code>image_1</code> is saved with its <code>id_1</code> as the filename, and each <code>image_2</code> with its <code>id_2</code>. In the next step, we will add two more columns to <code>df</code>: the filepaths of the <code>image_1</code>s and <code>image_2</code>s.</p> <div class="codehilite"><pre><span></span><code><span class="n">images_one_paths</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">images_two_paths</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">df</span><span class="p">)):</span> <span class="n">current_row</span> <span class="o">=</span> <span class="n">df</span><span class="o">.</span><span class="n">iloc</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="n">id_1</span> <span class="o">=</span> <span class="n">current_row</span><span class="p">[</span><span class="s2">&quot;id_1&quot;</span><span class="p">]</span> <span class="n">id_2</span> <span class="o">=</span> <span 
class="n">current_row</span><span class="p">[</span><span class="s2">&quot;id_2&quot;</span><span class="p">]</span> <span class="n">extension_one</span> <span class="o">=</span> <span class="n">current_row</span><span class="p">[</span><span class="s2">&quot;image_1&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;.&quot;</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">extension_two</span> <span class="o">=</span> <span class="n">current_row</span><span class="p">[</span><span class="s2">&quot;image_2&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot;.&quot;</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">image_one_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">image_base_path</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="n">id_1</span><span class="p">)</span> <span class="o">+</span> <span class="sa">f</span><span class="s2">&quot;.</span><span class="si">{</span><span class="n">extension_one</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="n">image_two_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">image_base_path</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="n">id_2</span><span class="p">)</span> <span class="o">+</span> <span class="sa">f</span><span class="s2">&quot;.</span><span 
class="si">{</span><span class="n">extension_two</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="n">images_one_paths</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">image_one_path</span><span class="p">)</span> <span class="n">images_two_paths</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">image_two_path</span><span class="p">)</span> <span class="n">df</span><span class="p">[</span><span class="s2">&quot;image_1_path&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">images_one_paths</span> <span class="n">df</span><span class="p">[</span><span class="s2">&quot;image_2_path&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">images_two_paths</span> <span class="c1"># Create another column containing the integer ids of</span> <span class="c1"># the string labels.</span> <span class="n">df</span><span class="p">[</span><span class="s2">&quot;label_idx&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="s2">&quot;label&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">label_map</span><span class="p">[</span><span class="n">x</span><span class="p">])</span> </code></pre></div> <hr /> <h2 id="dataset-visualization">Dataset visualization</h2> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">visualize</span><span class="p">(</span><span class="n">idx</span><span class="p">):</span> <span class="n">current_row</span> <span class="o">=</span> <span class="n">df</span><span class="o">.</span><span class="n">iloc</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span 
class="n">image_1</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">imread</span><span class="p">(</span><span class="n">current_row</span><span class="p">[</span><span class="s2">&quot;image_1_path&quot;</span><span class="p">])</span> <span class="n">image_2</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">imread</span><span class="p">(</span><span class="n">current_row</span><span class="p">[</span><span class="s2">&quot;image_2_path&quot;</span><span class="p">])</span> <span class="n">text_1</span> <span class="o">=</span> <span class="n">current_row</span><span class="p">[</span><span class="s2">&quot;text_1&quot;</span><span class="p">]</span> <span class="n">text_2</span> <span class="o">=</span> <span class="n">current_row</span><span class="p">[</span><span class="s2">&quot;text_2&quot;</span><span class="p">]</span> <span class="n">label</span> <span class="o">=</span> <span class="n">current_row</span><span class="p">[</span><span class="s2">&quot;label&quot;</span><span class="p">]</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image_1</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">&quot;off&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">&quot;Image One&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">1</span><span 
class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image_2</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">&quot;off&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">&quot;Image Two&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Text one: </span><span class="si">{</span><span class="n">text_1</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Text two: </span><span class="si">{</span><span class="n">text_2</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Label: </span><span class="si">{</span><span class="n">label</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="n">random_idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">df</span><span class="p">))</span> <span class="n">visualize</span><span class="p">(</span><span class="n">random_idx</span><span class="p">)</span> <span class="n">random_idx</span> <span class="o">=</span> <span class="n">np</span><span 
class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">df</span><span class="p">))</span> <span class="n">visualize</span><span class="p">(</span><span class="n">random_idx</span><span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/examples/nlp/multimodal_entailment/multimodal_entailment_14_0.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Text one: Friends, interested all go to have a look! @ThePartyGoddess @OurLadyAngels @BJsWholesale @Richard_Jeni @FashionLavidaG @RapaRooski @DMVTHING @DeMarcoReports @LobidaFo @DeMarcoMorgan https://t.co/cStULl7y7G Text two: Friends, interested all go to have a look! @smittyses @CYosabel @crum_7 @CrumDarrell @ElymalikU @jenloarn @SoCodiePrevost @roblowry82 @Crummy_14 @CSchmelzenbach https://t.co/IZphLTNzgl Label: Contradictory </code></pre></div> </div> <p><img alt="png" src="/img/examples/nlp/multimodal_entailment/multimodal_entailment_14_2.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Text one: 👟 KICK OFF @ MARDEN SPORTS COMPLEX </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>We&#39;re underway in the Round 6 opener! </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>📺: @Foxtel, @kayosports 📱: My Football Live app https://t.co/wHSpvQaoGC </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>#WLeague #ADLvMVC #AUFC #MVFC https://t.co/3Smp8KXm8W Text two: 👟 KICK OFF @ MARSDEN SPORTS COMPLEX </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>We&#39;re underway in sunny Adelaide! 
</code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>📺: @Foxtel, @kayosports 📱: My Football Live app https://t.co/wHSpvQaoGC </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>#ADLvCBR #WLeague #AUFC #UnitedAlways https://t.co/fG1PyLQXM4 Label: NoEntailment </code></pre></div> </div> <hr /> <h2 id="traintest-split">Train/test split</h2> <p>The dataset suffers from a <a href="https://developers.google.com/machine-learning/glossary#class-imbalanced-dataset">class imbalance</a> problem. We can confirm that in the following cell.</p> <div class="codehilite"><pre><span></span><code><span class="n">df</span><span class="p">[</span><span class="s2">&quot;label&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">value_counts</span><span class="p">()</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>NoEntailment 1182 Implies 109 Contradictory 109 Name: label, dtype: int64 </code></pre></div> </div> <p>To account for that, we will use a stratified split.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># 10% for test</span> <span class="n">train_df</span><span class="p">,</span> <span class="n">test_df</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span> <span class="n">df</span><span class="p">,</span> <span class="n">test_size</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">stratify</span><span class="o">=</span><span class="n">df</span><span class="p">[</span><span class="s2">&quot;label&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">values</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span> <span class="p">)</span> <span class="c1"># 5% for validation</span> <span 
class="n">train_df</span><span class="p">,</span> <span class="n">val_df</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span> <span class="n">train_df</span><span class="p">,</span> <span class="n">test_size</span><span class="o">=</span><span class="mf">0.05</span><span class="p">,</span> <span class="n">stratify</span><span class="o">=</span><span class="n">train_df</span><span class="p">[</span><span class="s2">&quot;label&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">values</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">42</span> <span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Total training examples: </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">train_df</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Total validation examples: </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">val_df</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Total test examples: </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">test_df</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Total training examples: 1197 Total validation examples: 63 Total test examples: 140 </code></pre></div> </div> <hr /> <h2 id="data-input-pipeline">Data input pipeline</h2> 
<p>TensorFlow Hub provides a <a href="https://www.tensorflow.org/text/tutorials/bert_glue#loading_models_from_tensorflow_hub">variety of models from the BERT family</a>. Each of these models comes with a corresponding preprocessing layer. You can learn more about these models and their preprocessing layers from <a href="https://www.tensorflow.org/text/tutorials/bert_glue#loading_models_from_tensorflow_hub">this resource</a>.</p> <p>To keep the runtime of this example relatively short, we will use a smaller variant of the original BERT model.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Define TF Hub paths to the BERT encoder and its preprocessor</span> <span class="n">bert_model_path</span> <span class="o">=</span> <span class="p">(</span> <span class="s2">&quot;https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-2_H-256_A-4/1&quot;</span> <span class="p">)</span> <span class="n">bert_preprocess_path</span> <span class="o">=</span> <span class="s2">&quot;https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3&quot;</span> </code></pre></div> <p>Our text preprocessing code mostly comes from <a href="https://www.tensorflow.org/text/tutorials/bert_glue">this tutorial</a>.
You are highly encouraged to check out the tutorial to learn more about the input preprocessing.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">make_bert_preprocessing_model</span><span class="p">(</span><span class="n">sentence_features</span><span class="p">,</span> <span class="n">seq_length</span><span class="o">=</span><span class="mi">128</span><span class="p">):</span> <span class="w"> </span><span class="sd">&quot;&quot;&quot;Returns Model mapping string features to BERT inputs.</span> <span class="sd"> Args:</span> <span class="sd"> sentence_features: A list with the names of string-valued features.</span> <span class="sd"> seq_length: An integer that defines the sequence length of BERT inputs.</span> <span class="sd"> Returns:</span> <span class="sd"> A Keras Model that can be called on a list or dict of string Tensors</span> <span class="sd"> (with the order or names, resp., given by sentence_features) and</span> <span class="sd"> returns a dict of tensors for input to BERT.</span> <span class="sd"> &quot;&quot;&quot;</span> <span class="n">input_segments</span> <span class="o">=</span> <span class="p">[</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">string</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="n">ft</span><span class="p">)</span> <span class="k">for</span> <span class="n">ft</span> <span class="ow">in</span> <span class="n">sentence_features</span> <span class="p">]</span> <span class="c1"># Tokenize the text to word pieces.</span> <span class="n">bert_preprocess</span> <span class="o">=</span> 
<span class="n">hub</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">bert_preprocess_path</span><span class="p">)</span> <span class="n">tokenizer</span> <span class="o">=</span> <span class="n">hub</span><span class="o">.</span><span class="n">KerasLayer</span><span class="p">(</span><span class="n">bert_preprocess</span><span class="o">.</span><span class="n">tokenize</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;tokenizer&quot;</span><span class="p">)</span> <span class="n">segments</span> <span class="o">=</span> <span class="p">[</span><span class="n">tokenizer</span><span class="p">(</span><span class="n">s</span><span class="p">)</span> <span class="k">for</span> <span class="n">s</span> <span class="ow">in</span> <span class="n">input_segments</span><span class="p">]</span> <span class="c1"># Optional: Trim segments in a smart way to fit seq_length.</span> <span class="c1"># Simple cases (like this example) can skip this step and let</span> <span class="c1"># the next step apply a default truncation to approximately equal lengths.</span> <span class="n">truncated_segments</span> <span class="o">=</span> <span class="n">segments</span> <span class="c1"># Pack inputs. 
The details (start/end token ids, dict of output tensors)</span> <span class="c1"># are model-dependent, so this gets loaded from the SavedModel.</span> <span class="n">packer</span> <span class="o">=</span> <span class="n">hub</span><span class="o">.</span><span class="n">KerasLayer</span><span class="p">(</span> <span class="n">bert_preprocess</span><span class="o">.</span><span class="n">bert_pack_inputs</span><span class="p">,</span> <span class="n">arguments</span><span class="o">=</span><span class="nb">dict</span><span class="p">(</span><span class="n">seq_length</span><span class="o">=</span><span class="n">seq_length</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;packer&quot;</span><span class="p">,</span> <span class="p">)</span> <span class="n">model_inputs</span> <span class="o">=</span> <span class="n">packer</span><span class="p">(</span><span class="n">truncated_segments</span><span class="p">)</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">input_segments</span><span class="p">,</span> <span class="n">model_inputs</span><span class="p">)</span> <span class="n">bert_preprocess_model</span> <span class="o">=</span> <span class="n">make_bert_preprocessing_model</span><span class="p">([</span><span class="s2">&quot;text_1&quot;</span><span class="p">,</span> <span class="s2">&quot;text_2&quot;</span><span class="p">])</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">plot_model</span><span class="p">(</span><span class="n">bert_preprocess_model</span><span class="p">,</span> <span class="n">show_shapes</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">show_dtype</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> </code></pre></div> 
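<p>The <code>bert_pack_inputs</code> step above is what turns the two tokenized text segments into the standard BERT input format. As a rough illustration of what it produces (a toy sketch only — the token ids below are made up, while the real TF Hub preprocessor derives them from the model's WordPiece vocabulary and its own special-token ids), the packing logic looks roughly like this:</p>

```python
# Toy sketch of BERT-style two-segment packing: [CLS] A [SEP] B [SEP] + padding.
# The special-token ids here are hypothetical placeholders for illustration.
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0

def pack_two_segments(ids_a, ids_b, seq_length=16):
    """Pack two token-id lists into word ids, segment type ids, and a padding mask."""
    word_ids = [CLS_ID] + ids_a + [SEP_ID] + ids_b + [SEP_ID]
    # Segment A (plus [CLS] and its [SEP]) gets type 0, segment B gets type 1.
    type_ids = [0] * (len(ids_a) + 2) + [1] * (len(ids_b) + 1)
    # The mask marks real tokens with 1 and padding with 0.
    mask = [1] * len(word_ids)
    pad = seq_length - len(word_ids)
    return {
        "input_word_ids": word_ids + [PAD_ID] * pad,
        "input_type_ids": type_ids + [0] * pad,
        "input_mask": mask + [0] * pad,
    }

packed = pack_two_segments([7592, 2088], [2748], seq_length=8)
print(packed["input_word_ids"])  # [101, 7592, 2088, 102, 2748, 102, 0, 0]
print(packed["input_type_ids"])  # [0, 0, 0, 0, 1, 1, 0, 0]
print(packed["input_mask"])      # [1, 1, 1, 1, 1, 1, 0, 0]
```

<p>These three outputs correspond to the keys (<code>input_word_ids</code>, <code>input_type_ids</code>, <code>input_mask</code>) returned by the actual preprocessing model.</p>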
<p><img alt="png" src="/img/examples/nlp/multimodal_entailment/multimodal_entailment_22_0.png" /></p> <h3 id="run-the-preprocessor-on-a-sample-input">Run the preprocessor on a sample input</h3> <div class="codehilite"><pre><span></span><code><span class="n">idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">train_df</span><span class="p">))</span> <span class="n">row</span> <span class="o">=</span> <span class="n">train_df</span><span class="o">.</span><span class="n">iloc</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="n">sample_text_1</span><span class="p">,</span> <span class="n">sample_text_2</span> <span class="o">=</span> <span class="n">row</span><span class="p">[</span><span class="s2">&quot;text_1&quot;</span><span class="p">],</span> <span class="n">row</span><span class="p">[</span><span class="s2">&quot;text_2&quot;</span><span class="p">]</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Text 1: </span><span class="si">{</span><span class="n">sample_text_1</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Text 2: </span><span class="si">{</span><span class="n">sample_text_2</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="n">test_text</span> <span class="o">=</span> <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">sample_text_1</span><span class="p">]),</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span 
class="n">sample_text_2</span><span class="p">])]</span> <span class="n">text_preprocessed</span> <span class="o">=</span> <span class="n">bert_preprocess_model</span><span class="p">(</span><span class="n">test_text</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Keys : &quot;</span><span class="p">,</span> <span class="nb">list</span><span class="p">(</span><span class="n">text_preprocessed</span><span class="o">.</span><span class="n">keys</span><span class="p">()))</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Shape Word Ids : &quot;</span><span class="p">,</span> <span class="n">text_preprocessed</span><span class="p">[</span><span class="s2">&quot;input_word_ids&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Word Ids : &quot;</span><span class="p">,</span> <span class="n">text_preprocessed</span><span class="p">[</span><span class="s2">&quot;input_word_ids&quot;</span><span class="p">][</span><span class="mi">0</span><span class="p">,</span> <span class="p">:</span><span class="mi">16</span><span class="p">])</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Shape Mask : &quot;</span><span class="p">,</span> <span class="n">text_preprocessed</span><span class="p">[</span><span class="s2">&quot;input_mask&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Input Mask : &quot;</span><span class="p">,</span> <span class="n">text_preprocessed</span><span class="p">[</span><span class="s2">&quot;input_mask&quot;</span><span class="p">][</span><span class="mi">0</span><span class="p">,</span> <span class="p">:</span><span class="mi">16</span><span 
class="p">])</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Shape Type Ids : &quot;</span><span class="p">,</span> <span class="n">text_preprocessed</span><span class="p">[</span><span class="s2">&quot;input_type_ids&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Type Ids : &quot;</span><span class="p">,</span> <span class="n">text_preprocessed</span><span class="p">[</span><span class="s2">&quot;input_type_ids&quot;</span><span class="p">][</span><span class="mi">0</span><span class="p">,</span> <span class="p">:</span><span class="mi">16</span><span class="p">])</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Text 1: Renewables met 97% of Scotland&#39;s electricity demand in 2020!!!! https://t.co/wi5c9UFAUF https://t.co/arcuBgh0BP Text 2: Renewables met 97% of Scotland&#39;s electricity demand in 2020 https://t.co/SrhyqPnIkU https://t.co/LORgvTM7Sn Keys : [&#39;input_mask&#39;, &#39;input_word_ids&#39;, &#39;input_type_ids&#39;] Shape Word Ids : (1, 128) Word Ids : tf.Tensor( [ 101 13918 2015 2777 5989 1003 1997 3885 1005 1055 6451 5157 1999 12609 999 999], shape=(16,), dtype=int32) Shape Mask : (1, 128) Input Mask : tf.Tensor([1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1], shape=(16,), dtype=int32) Shape Type Ids : (1, 128) Type Ids : tf.Tensor([0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], shape=(16,), dtype=int32) </code></pre></div> </div> <p>We will now create <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><code>tf.data.Dataset</code></a> objects from the dataframes.</p> <p>Note that the text inputs will be preprocessed as a part of the data input pipeline. But the preprocessing modules can also be a part of their corresponding BERT models. 
This helps reduce the training/serving skew and lets our models operate with raw text inputs. Follow <a href="https://www.tensorflow.org/text/tutorials/classify_text_with_bert">this tutorial</a> to learn more about how to incorporate the preprocessing modules directly inside the models.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">dataframe_to_dataset</span><span class="p">(</span><span class="n">dataframe</span><span class="p">):</span> <span class="n">columns</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;image_1_path&quot;</span><span class="p">,</span> <span class="s2">&quot;image_2_path&quot;</span><span class="p">,</span> <span class="s2">&quot;text_1&quot;</span><span class="p">,</span> <span class="s2">&quot;text_2&quot;</span><span class="p">,</span> <span class="s2">&quot;label_idx&quot;</span><span class="p">]</span> <span class="n">dataframe</span> <span class="o">=</span> <span class="n">dataframe</span><span class="p">[</span><span class="n">columns</span><span class="p">]</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">dataframe</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">&quot;label_idx&quot;</span><span class="p">)</span> <span class="n">ds</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="nb">dict</span><span class="p">(</span><span class="n">dataframe</span><span class="p">),</span> <span class="n">labels</span><span class="p">))</span> <span class="n">ds</span> <span class="o">=</span> <span class="n">ds</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span 
class="n">buffer_size</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">dataframe</span><span class="p">))</span> <span class="k">return</span> <span class="n">ds</span> </code></pre></div> <h3 id="preprocessing-utilities">Preprocessing utilities</h3> <div class="codehilite"><pre><span></span><code><span class="n">resize</span> <span class="o">=</span> <span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">)</span> <span class="n">bert_input_features</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;input_word_ids&quot;</span><span class="p">,</span> <span class="s2">&quot;input_type_ids&quot;</span><span class="p">,</span> <span class="s2">&quot;input_mask&quot;</span><span class="p">]</span> <span class="k">def</span> <span class="nf">preprocess_image</span><span class="p">(</span><span class="n">image_path</span><span class="p">):</span> <span class="c1"># Split on &quot;.&quot; so the last piece is the file extension.</span> <span class="n">extension</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">strings</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">image_path</span><span class="p">,</span> <span class="s2">&quot;.&quot;</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">read_file</span><span class="p">(</span><span class="n">image_path</span><span class="p">)</span> <span class="k">if</span> <span class="n">extension</span> <span class="o">==</span> <span class="sa">b</span><span class="s2">&quot;jpg&quot;</span><span class="p">:</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">decode_jpeg</span><span class="p">(</span><span
class="n">image</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">decode_png</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">resize</span><span class="p">)</span> <span class="k">return</span> <span class="n">image</span> <span class="k">def</span> <span class="nf">preprocess_text</span><span class="p">(</span><span class="n">text_1</span><span class="p">,</span> <span class="n">text_2</span><span class="p">):</span> <span class="n">text_1</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">convert_to_tensor</span><span class="p">([</span><span class="n">text_1</span><span class="p">])</span> <span class="n">text_2</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">convert_to_tensor</span><span class="p">([</span><span class="n">text_2</span><span class="p">])</span> <span class="n">output</span> <span class="o">=</span> <span class="n">bert_preprocess_model</span><span class="p">([</span><span class="n">text_1</span><span class="p">,</span> <span class="n">text_2</span><span class="p">])</span> <span class="n">output</span> <span class="o">=</span> <span class="p">{</span><span class="n">feature</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">output</span><span 
class="p">[</span><span class="n">feature</span><span class="p">])</span> <span class="k">for</span> <span class="n">feature</span> <span class="ow">in</span> <span class="n">bert_input_features</span><span class="p">}</span> <span class="k">return</span> <span class="n">output</span> <span class="k">def</span> <span class="nf">preprocess_text_and_image</span><span class="p">(</span><span class="n">sample</span><span class="p">):</span> <span class="n">image_1</span> <span class="o">=</span> <span class="n">preprocess_image</span><span class="p">(</span><span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image_1_path&quot;</span><span class="p">])</span> <span class="n">image_2</span> <span class="o">=</span> <span class="n">preprocess_image</span><span class="p">(</span><span class="n">sample</span><span class="p">[</span><span class="s2">&quot;image_2_path&quot;</span><span class="p">])</span> <span class="n">text</span> <span class="o">=</span> <span class="n">preprocess_text</span><span class="p">(</span><span class="n">sample</span><span class="p">[</span><span class="s2">&quot;text_1&quot;</span><span class="p">],</span> <span class="n">sample</span><span class="p">[</span><span class="s2">&quot;text_2&quot;</span><span class="p">])</span> <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;image_1&quot;</span><span class="p">:</span> <span class="n">image_1</span><span class="p">,</span> <span class="s2">&quot;image_2&quot;</span><span class="p">:</span> <span class="n">image_2</span><span class="p">,</span> <span class="s2">&quot;text&quot;</span><span class="p">:</span> <span class="n">text</span><span class="p">}</span> </code></pre></div> <h3 id="create-the-final-datasets">Create the final datasets</h3> <div class="codehilite"><pre><span></span><code><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">32</span> <span class="n">auto</span> <span class="o">=</span> <span 
class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span> <span class="k">def</span> <span class="nf">prepare_dataset</span><span class="p">(</span><span class="n">dataframe</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span> <span class="n">ds</span> <span class="o">=</span> <span class="n">dataframe_to_dataset</span><span class="p">(</span><span class="n">dataframe</span><span class="p">)</span> <span class="k">if</span> <span class="n">training</span><span class="p">:</span> <span class="n">ds</span> <span class="o">=</span> <span class="n">ds</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">train_df</span><span class="p">))</span> <span class="n">ds</span> <span class="o">=</span> <span class="n">ds</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="p">(</span><span class="n">preprocess_text_and_image</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">y</span><span class="p">))</span><span class="o">.</span><span class="n">cache</span><span class="p">()</span> <span class="n">ds</span> <span class="o">=</span> <span class="n">ds</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span><span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">auto</span><span class="p">)</span> <span class="k">return</span> <span class="n">ds</span> <span class="n">train_ds</span> <span class="o">=</span> <span class="n">prepare_dataset</span><span class="p">(</span><span 
class="n">train_df</span><span class="p">)</span> <span class="n">validation_ds</span> <span class="o">=</span> <span class="n">prepare_dataset</span><span class="p">(</span><span class="n">val_df</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span> <span class="n">test_ds</span> <span class="o">=</span> <span class="n">prepare_dataset</span><span class="p">(</span><span class="n">test_df</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="model-building-utilities">Model building utilities</h2> <p>Our final model will accept two images along with their text counterparts. While the images will be fed directly to the model, the text inputs will first be preprocessed and then fed into the model. Below is a visual illustration of this approach:</p> <p><img alt="" src="https://github.com/sayakpaul/Multimodal-Entailment-Baseline/raw/main/figures/brief_architecture.png" /></p> <p>The model consists of the following elements:</p> <ul> <li>A standalone encoder for the images. We will use a <a href="https://arxiv.org/abs/1603.05027">ResNet50V2</a> pre-trained on the ImageNet-1k dataset for this.</li> <li>A standalone encoder for the text. A pre-trained BERT model will be used for this.</li> </ul> <p>After extracting the individual embeddings, they will be projected into an identical space.
Finally, their projections will be concatenated and fed to the final classification layer.</p> <p>This is a multi-class classification problem involving the following classes:</p> <ul> <li>NoEntailment</li> <li>Implies</li> <li>Contradictory</li> </ul> <p>The <code>project_embeddings()</code>, <code>create_vision_encoder()</code>, and <code>create_text_encoder()</code> utilities are adapted from <a href="https://keras.io/examples/nlp/nl_image_search/">this example</a>.</p> <p>Projection utilities</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">project_embeddings</span><span class="p">(</span> <span class="n">embeddings</span><span class="p">,</span> <span class="n">num_projection_layers</span><span class="p">,</span> <span class="n">projection_dims</span><span class="p">,</span> <span class="n">dropout_rate</span> <span class="p">):</span> <span class="n">projected_embeddings</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="n">projection_dims</span><span class="p">)(</span><span class="n">embeddings</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_projection_layers</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">gelu</span><span class="p">(</span><span class="n">projected_embeddings</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span
class="n">projection_dims</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_rate</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Add</span><span class="p">()([</span><span class="n">projected_embeddings</span><span class="p">,</span> <span class="n">x</span><span class="p">])</span> <span class="n">projected_embeddings</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">projected_embeddings</span> </code></pre></div> <p>Vision encoder utilities</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">create_vision_encoder</span><span class="p">(</span> <span class="n">num_projection_layers</span><span class="p">,</span> <span class="n">projection_dims</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">False</span> <span class="p">):</span> <span class="c1"># Load the pre-trained ResNet50V2 model to be used as the base encoder.</span> <span class="n">resnet_v2</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">applications</span><span class="o">.</span><span class="n">ResNet50V2</span><span class="p">(</span> <span class="n">include_top</span><span 
class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">weights</span><span class="o">=</span><span class="s2">&quot;imagenet&quot;</span><span class="p">,</span> <span class="n">pooling</span><span class="o">=</span><span class="s2">&quot;avg&quot;</span> <span class="p">)</span> <span class="c1"># Set the trainability of the base encoder.</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">resnet_v2</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span> <span class="n">layer</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="n">trainable</span> <span class="c1"># Receive the images as inputs.</span> <span class="n">image_1</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;image_1&quot;</span><span class="p">)</span> <span class="n">image_2</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;image_2&quot;</span><span class="p">)</span> <span class="c1"># Preprocess the input image.</span> <span class="n">preprocessed_1</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span 
class="n">applications</span><span class="o">.</span><span class="n">resnet_v2</span><span class="o">.</span><span class="n">preprocess_input</span><span class="p">(</span><span class="n">image_1</span><span class="p">)</span> <span class="n">preprocessed_2</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">applications</span><span class="o">.</span><span class="n">resnet_v2</span><span class="o">.</span><span class="n">preprocess_input</span><span class="p">(</span><span class="n">image_2</span><span class="p">)</span> <span class="c1"># Generate the embeddings for the images using the resnet_v2 model</span> <span class="c1"># concatenate them.</span> <span class="n">embeddings_1</span> <span class="o">=</span> <span class="n">resnet_v2</span><span class="p">(</span><span class="n">preprocessed_1</span><span class="p">)</span> <span class="n">embeddings_2</span> <span class="o">=</span> <span class="n">resnet_v2</span><span class="p">(</span><span class="n">preprocessed_2</span><span class="p">)</span> <span class="n">embeddings</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Concatenate</span><span class="p">()([</span><span class="n">embeddings_1</span><span class="p">,</span> <span class="n">embeddings_2</span><span class="p">])</span> <span class="c1"># Project the embeddings produced by the model.</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">project_embeddings</span><span class="p">(</span> <span class="n">embeddings</span><span class="p">,</span> <span class="n">num_projection_layers</span><span class="p">,</span> <span class="n">projection_dims</span><span class="p">,</span> <span class="n">dropout_rate</span> <span class="p">)</span> <span class="c1"># Create the vision encoder model.</span> <span class="k">return</span> <span class="n">keras</span><span 
class="o">.</span><span class="n">Model</span><span class="p">([</span><span class="n">image_1</span><span class="p">,</span> <span class="n">image_2</span><span class="p">],</span> <span class="n">outputs</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;vision_encoder&quot;</span><span class="p">)</span> </code></pre></div> <p>Text encoder utilities</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">create_text_encoder</span><span class="p">(</span> <span class="n">num_projection_layers</span><span class="p">,</span> <span class="n">projection_dims</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">False</span> <span class="p">):</span> <span class="c1"># Load the pre-trained BERT model to be used as the base encoder.</span> <span class="n">bert</span> <span class="o">=</span> <span class="n">hub</span><span class="o">.</span><span class="n">KerasLayer</span><span class="p">(</span><span class="n">bert_model_path</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;bert&quot;</span><span class="p">,)</span> <span class="c1"># Set the trainability of the base encoder.</span> <span class="n">bert</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="n">trainable</span> <span class="c1"># Receive the text as inputs.</span> <span class="n">bert_input_features</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;input_type_ids&quot;</span><span class="p">,</span> <span class="s2">&quot;input_mask&quot;</span><span class="p">,</span> <span class="s2">&quot;input_word_ids&quot;</span><span class="p">]</span> <span class="n">inputs</span> <span class="o">=</span> <span class="p">{</span> <span class="n">feature</span><span 
class="p">:</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">128</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="n">feature</span><span class="p">)</span> <span class="k">for</span> <span class="n">feature</span> <span class="ow">in</span> <span class="n">bert_input_features</span> <span class="p">}</span> <span class="c1"># Generate embeddings for the preprocessed text using the BERT model.</span> <span class="n">embeddings</span> <span class="o">=</span> <span class="n">bert</span><span class="p">(</span><span class="n">inputs</span><span class="p">)[</span><span class="s2">&quot;pooled_output&quot;</span><span class="p">]</span> <span class="c1"># Project the embeddings produced by the model.</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">project_embeddings</span><span class="p">(</span> <span class="n">embeddings</span><span class="p">,</span> <span class="n">num_projection_layers</span><span class="p">,</span> <span class="n">projection_dims</span><span class="p">,</span> <span class="n">dropout_rate</span> <span class="p">)</span> <span class="c1"># Create the text encoder model.</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;text_encoder&quot;</span><span class="p">)</span> </code></pre></div> <p>Multimodal model utilities</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> 
<span class="nf">create_multimodal_model</span><span class="p">(</span> <span class="n">num_projection_layers</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">projection_dims</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">vision_trainable</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">text_trainable</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="p">):</span> <span class="c1"># Receive the images as inputs.</span> <span class="n">image_1</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;image_1&quot;</span><span class="p">)</span> <span class="n">image_2</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;image_2&quot;</span><span class="p">)</span> <span class="c1"># Receive the text as inputs.</span> <span class="n">bert_input_features</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;input_type_ids&quot;</span><span class="p">,</span> <span 
class="s2">&quot;input_mask&quot;</span><span class="p">,</span> <span class="s2">&quot;input_word_ids&quot;</span><span class="p">]</span> <span class="n">text_inputs</span> <span class="o">=</span> <span class="p">{</span> <span class="n">feature</span><span class="p">:</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">128</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="n">feature</span><span class="p">)</span> <span class="k">for</span> <span class="n">feature</span> <span class="ow">in</span> <span class="n">bert_input_features</span> <span class="p">}</span> <span class="c1"># Create the encoders.</span> <span class="n">vision_encoder</span> <span class="o">=</span> <span class="n">create_vision_encoder</span><span class="p">(</span> <span class="n">num_projection_layers</span><span class="p">,</span> <span class="n">projection_dims</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="p">,</span> <span class="n">vision_trainable</span> <span class="p">)</span> <span class="n">text_encoder</span> <span class="o">=</span> <span class="n">create_text_encoder</span><span class="p">(</span> <span class="n">num_projection_layers</span><span class="p">,</span> <span class="n">projection_dims</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="p">,</span> <span class="n">text_trainable</span> <span class="p">)</span> <span class="c1"># Fetch the embedding projections.</span> <span class="n">vision_projections</span> <span class="o">=</span> <span class="n">vision_encoder</span><span class="p">([</span><span class="n">image_1</span><span class="p">,</span> <span 
class="n">image_2</span><span class="p">])</span> <span class="n">text_projections</span> <span class="o">=</span> <span class="n">text_encoder</span><span class="p">(</span><span class="n">text_inputs</span><span class="p">)</span> <span class="c1"># Concatenate the projections and pass through the classification layer.</span> <span class="n">concatenated</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Concatenate</span><span class="p">()([</span><span class="n">vision_projections</span><span class="p">,</span> <span class="n">text_projections</span><span class="p">])</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;softmax&quot;</span><span class="p">)(</span><span class="n">concatenated</span><span class="p">)</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">([</span><span class="n">image_1</span><span class="p">,</span> <span class="n">image_2</span><span class="p">,</span> <span class="n">text_inputs</span><span class="p">],</span> <span class="n">outputs</span><span class="p">)</span> <span class="n">multimodal_model</span> <span class="o">=</span> <span class="n">create_multimodal_model</span><span class="p">()</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">plot_model</span><span class="p">(</span><span class="n">multimodal_model</span><span class="p">,</span> <span class="n">show_shapes</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> </code></pre></div> <p><img 
alt="png" src="/img/examples/nlp/multimodal_entailment/multimodal_entailment_39_0.png" /></p> <p>You can inspect the structure of the individual encoders as well by setting the <code>expand_nested</code> argument of <code>plot_model()</code> to <code>True</code>. You are encouraged to play with the different hyperparameters involved in building this model and observe how the final performance is affected.</p> <hr /> <h2 id="compile-and-train-the-model">Compile and train the model</h2> <div class="codehilite"><pre><span></span><code><span class="n">multimodal_model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="s2">&quot;adam&quot;</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s2">&quot;sparse_categorical_crossentropy&quot;</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="s2">&quot;accuracy&quot;</span> <span class="p">)</span> <span class="n">history</span> <span class="o">=</span> <span class="n">multimodal_model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_ds</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">validation_ds</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/10 38/38 [==============================] - 49s 789ms/step - loss: 1.0014 - accuracy: 0.8229 - val_loss: 0.5514 - val_accuracy: 0.8571 Epoch 2/10 38/38 [==============================] - 3s 90ms/step - loss: 0.4019 - accuracy: 0.8814 - val_loss: 0.5866 - val_accuracy: 0.8571 Epoch 3/10 38/38 [==============================] - 3s 90ms/step - loss: 0.3557 - accuracy: 0.8897 - val_loss: 0.5929 
- val_accuracy: 0.8571 Epoch 4/10 38/38 [==============================] - 3s 91ms/step - loss: 0.2877 - accuracy: 0.9006 - val_loss: 0.6272 - val_accuracy: 0.8571 Epoch 5/10 38/38 [==============================] - 3s 91ms/step - loss: 0.1796 - accuracy: 0.9398 - val_loss: 0.8545 - val_accuracy: 0.8254 Epoch 6/10 38/38 [==============================] - 3s 91ms/step - loss: 0.1292 - accuracy: 0.9566 - val_loss: 1.2276 - val_accuracy: 0.8413 Epoch 7/10 38/38 [==============================] - 3s 91ms/step - loss: 0.1015 - accuracy: 0.9666 - val_loss: 1.2914 - val_accuracy: 0.7778 Epoch 8/10 38/38 [==============================] - 3s 92ms/step - loss: 0.1253 - accuracy: 0.9524 - val_loss: 1.1944 - val_accuracy: 0.8413 Epoch 9/10 38/38 [==============================] - 3s 92ms/step - loss: 0.3064 - accuracy: 0.9131 - val_loss: 1.2162 - val_accuracy: 0.8095 Epoch 10/10 38/38 [==============================] - 3s 92ms/step - loss: 0.2212 - accuracy: 0.9248 - val_loss: 1.1080 - val_accuracy: 0.8413 </code></pre></div> </div> <hr /> <h2 id="evaluate-the-model">Evaluate the model</h2> <div class="codehilite"><pre><span></span><code><span class="n">_</span><span class="p">,</span> <span class="n">acc</span> <span class="o">=</span> <span class="n">multimodal_model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">test_ds</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Accuracy on the test set: </span><span class="si">{</span><span class="nb">round</span><span class="p">(</span><span class="n">acc</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="mi">100</span><span class="p">,</span><span class="w"> </span><span class="mi">2</span><span class="p">)</span><span class="si">}</span><span class="s2">%.&quot;</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div 
class="codehilite"><pre><span></span><code>5/5 [==============================] - 6s 1s/step - loss: 0.8390 - accuracy: 0.8429 Accuracy on the test set: 84.29%. </code></pre></div> </div> <hr /> <h2 id="additional-notes-regarding-training">Additional notes regarding training</h2> <p><strong>Incorporating regularization</strong>:</p> <p>The training logs suggest that the model started to overfit and might have benefited from regularization. Dropout (<a href="https://jmlr.org/papers/v15/srivastava14a.html">Srivastava et al.</a>) is a simple yet powerful regularization technique that we can use in our model. But how should we apply it here?</p> <p>We could always introduce Dropout (<a href="/api/layers/regularization_layers/dropout#dropout-class"><code>keras.layers.Dropout</code></a>) between different layers of the model. But here is another recipe. Our model expects inputs from two different data modalities. What if either modality is absent during inference? To account for this, we can introduce Dropout to the individual projections just before they get concatenated:</p> <div class="codehilite"><pre><span></span><code><span class="n">vision_projections</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">rate</span><span class="p">)(</span><span class="n">vision_projections</span><span class="p">)</span> <span class="n">text_projections</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">rate</span><span class="p">)(</span><span class="n">text_projections</span><span class="p">)</span> <span class="n">concatenated</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span
class="o">.</span><span class="n">Concatenate</span><span class="p">()([</span><span class="n">vision_projections</span><span class="p">,</span> <span class="n">text_projections</span><span class="p">])</span> </code></pre></div> <p><strong>Attending to what matters</strong>:</p> <p>Do all parts of the images correspond equally to their textual counterparts? It's likely not the case. To make our model focus only on the image regions that relate most strongly to their textual counterparts, we can use "cross-attention":</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Embeddings.</span> <span class="n">vision_projections</span> <span class="o">=</span> <span class="n">vision_encoder</span><span class="p">([</span><span class="n">image_1</span><span class="p">,</span> <span class="n">image_2</span><span class="p">])</span> <span class="n">text_projections</span> <span class="o">=</span> <span class="n">text_encoder</span><span class="p">(</span><span class="n">text_inputs</span><span class="p">)</span> <span class="c1"># Cross-attention (Luong-style).</span> <span class="n">query_value_attention_seq</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Attention</span><span class="p">(</span><span class="n">use_scale</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.2</span><span class="p">)(</span> <span class="p">[</span><span class="n">vision_projections</span><span class="p">,</span> <span class="n">text_projections</span><span class="p">]</span> <span class="p">)</span> <span class="c1"># Concatenate.</span> <span class="n">concatenated</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Concatenate</span><span
class="p">()([</span><span class="n">vision_projections</span><span class="p">,</span> <span class="n">text_projections</span><span class="p">])</span> <span class="n">contextual</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Concatenate</span><span class="p">()([</span><span class="n">concatenated</span><span class="p">,</span> <span class="n">query_value_attention_seq</span><span class="p">])</span> </code></pre></div> <p>To see this in action, refer to <a href="https://github.com/sayakpaul/Multimodal-Entailment-Baseline/blob/main/multimodal_entailment_attn.ipynb">this notebook</a>.</p> <p><strong>Handling class imbalance</strong>:</p> <p>The dataset suffers from class imbalance. Investigating the confusion matrix of the above model reveals that it performs poorly on the minority classes. Had we used a weighted loss, the training would have been guided more effectively. You can check out <a href="https://github.com/sayakpaul/Multimodal-Entailment-Baseline/blob/main/multimodal_entailment.ipynb">this notebook</a> that takes class imbalance into account during model training.</p> <p><strong>Using only text inputs</strong>:</p> <p>What if we had incorporated only text inputs for the entailment task? Because of the nature of the text encountered on social media platforms, text inputs alone would have hurt the final performance. Under a similar training setup, using only text inputs we reach 67.14% top-1 accuracy on the same test set.
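The weighted-loss recipe from the class-imbalance note above can be sketched with per-class weights that are inversely proportional to class frequency. The label counts below are made up for illustration and are not the actual dataset statistics:

```python
from collections import Counter

# Hypothetical integer labels for the 3-class entailment problem.
# The counts here are illustrative only.
labels = [0] * 100 + [1] * 20 + [2] * 10

# Inverse-frequency weights: rarer classes receive larger weights,
# so their errors contribute more to the loss.
num_classes = 3
counts = Counter(labels)
class_weight = {
    label: len(labels) / (num_classes * count)
    for label, count in counts.items()
}
```

With `sparse_categorical_crossentropy`, one common way to apply such weights is `multimodal_model.fit(train_ds, class_weight=class_weight, ...)`, although `class_weight` support depends on how the input pipeline yields its targets.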
Refer to <a href="https://github.com/sayakpaul/Multimodal-Entailment-Baseline/blob/main/text_entailment.ipynb">this notebook</a> for details.</p> <p>Finally, here is a table comparing different approaches taken for the entailment task:</p> <table> <thead> <tr> <th style="text-align: center;">Type</th> <th style="text-align: center;">Standard<br>Cross-entropy</th> <th style="text-align: center;">Loss-weighted<br>Cross-entropy</th> <th style="text-align: center;">Focal Loss</th> </tr> </thead> <tbody> <tr> <td style="text-align: center;">Multimodal</td> <td style="text-align: center;">77.86%</td> <td style="text-align: center;">67.86%</td> <td style="text-align: center;">86.43%</td> </tr> <tr> <td style="text-align: center;">Only text</td> <td style="text-align: center;">67.14%</td> <td style="text-align: center;">11.43%</td> <td style="text-align: center;">37.86%</td> </tr> </tbody> </table> <p>You can check out <a href="https://git.io/JR0HU">this repository</a> to learn more about how the experiments were conducted to obtain these numbers.</p> <hr /> <h2 id="final-remarks">Final remarks</h2> <ul> <li>The architecture we used in this example is too large for the number of data points available for training. It would benefit from more data.</li> <li>We used a smaller variant of the original BERT model. Chances are high that a larger variant would improve this performance. TensorFlow Hub <a href="https://www.tensorflow.org/text/tutorials/bert_glue#loading_models_from_tensorflow_hub">provides</a> a number of different BERT models that you can experiment with.</li> <li>We kept the pre-trained models frozen. Fine-tuning them on the multimodal entailment task could have resulted in better performance.</li> <li>We built a simple baseline model for the multimodal entailment task. Various approaches have been proposed to tackle the entailment problem.
<a href="https://docs.google.com/presentation/d/1mAB31BCmqzfedreNZYn4hsKPFmgHA9Kxz219DzyRY3c/edit?usp=sharing">This presentation deck</a> from the <a href="https://multimodal-entailment.github.io/">Recognizing Multimodal Entailment</a> tutorial provides a comprehensive overview.</li> </ul> <p>You can use the trained model hosted on <a href="https://huggingface.co/keras-io/multimodal-entailment">Hugging Face Hub</a> and try the demo on <a href="https://huggingface.co/spaces/keras-io/multimodal_entailment">Hugging Face Spaces</a>.</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#multimodal-entailment'>Multimodal entailment</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-3'> <a href='#what-is-multimodal-entailment'>What is multimodal entailment?</a> </div> <div class='k-outline-depth-3'> <a href='#requirements'>Requirements</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#imports'>Imports</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#define-a-label-map'>Define a label map</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#collect-the-dataset'>Collect the dataset</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#read-the-dataset-and-apply-basic-preprocessing'>Read the dataset and apply basic preprocessing</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#dataset-visualization'>Dataset visualization</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#traintest-split'>Train/test split</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#data-input-pipeline'>Data input pipeline</a> </div> <div class='k-outline-depth-3'> <a href='#run-the-preprocessor-on-a-sample-input'>Run the preprocessor on a sample input</a> </div> <div class='k-outline-depth-3'> <a href='#preprocessing-utilities'>Preprocessing utilities</a> </div> <div class='k-outline-depth-3'> <a href='#create-the-final-datasets'>Create the final datasets</a> </div> <div
class='k-outline-depth-2'> ◆ <a href='#model-building-utilities'>Model building utilities</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#compile-and-train-the-model'>Compile and train the model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#evaluate-the-model'>Evaluate the model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#additional-notes-regarding-training'>Additional notes regarding training</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#final-remarks'>Final remarks</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>
