CINXE.COM
Text Extraction with BERT
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/nlp/text_extraction_with_bert/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Text Extraction with BERT"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Text Extraction with BERT"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Text Extraction with BERT</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink active" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_from_scratch/">Text classification from scratch</a> <a class="nav-sublink2" href="/examples/nlp/active_learning_review_classification/">Review Classification using Active Learning</a> <a class="nav-sublink2" href="/examples/nlp/fnet_classification_with_keras_hub/">Text Classification using FNet</a> <a class="nav-sublink2" href="/examples/nlp/multi_label_classification/">Large-scale multi-label text classification</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_with_transformer/">Text classification with Transformer</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_with_switch_transformer/">Text classification with Switch Transformer</a> <a class="nav-sublink2" href="/examples/nlp/tweet-classification-using-tfdf/">Text classification using Decision Forests and pretrained embeddings</a> <a class="nav-sublink2" href="/examples/nlp/pretrained_word_embeddings/">Using pre-trained word embeddings</a> <a class="nav-sublink2" href="/examples/nlp/bidirectional_lstm_imdb/">Bidirectional LSTM on IMDB</a> <a class="nav-sublink2" href="/examples/nlp/data_parallel_training_with_keras_hub/">Data Parallel Training with KerasHub and tf.distribute</a> <a class="nav-sublink2" href="/examples/nlp/neural_machine_translation_with_keras_hub/">English-to-Spanish translation with KerasHub</a> <a class="nav-sublink2" href="/examples/nlp/neural_machine_translation_with_transformer/">English-to-Spanish translation with a sequence-to-sequence Transformer</a> <a class="nav-sublink2" href="/examples/nlp/lstm_seq2seq/">Character-level recurrent sequence-to-sequence model</a> <a class="nav-sublink2" href="/examples/nlp/multimodal_entailment/">Multimodal entailment</a> <a class="nav-sublink2" href="/examples/nlp/ner_transformers/">Named Entity Recognition using Transformers</a> <a class="nav-sublink2 active" href="/examples/nlp/text_extraction_with_bert/">Text Extraction with BERT</a> <a class="nav-sublink2" href="/examples/nlp/addition_rnn/">Sequence to sequence learning for performing number addition</a> <a class="nav-sublink2" href="/examples/nlp/semantic_similarity_with_keras_hub/">Semantic Similarity with KerasHub</a> <a class="nav-sublink2" href="/examples/nlp/semantic_similarity_with_bert/">Semantic Similarity with BERT</a> <a class="nav-sublink2" href="/examples/nlp/sentence_embeddings_with_sbert/">Sentence embeddings using Siamese RoBERTa-networks</a> <a class="nav-sublink2" href="/examples/nlp/masked_language_modeling/">End-to-end Masked Language Modeling with BERT</a> <a class="nav-sublink2" href="/examples/nlp/abstractive_summarization_with_bart/">Abstractive Text Summarization with BART</a> <a class="nav-sublink2" href="/examples/nlp/pretraining_BERT/">Pretraining BERT with Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/">Parameter-efficient fine-tuning of GPT-2 with LoRA</a> <a class="nav-sublink2" href="/examples/nlp/multiple_choice_task_with_transfer_learning/">MultipleChoice Task with Transfer Learning</a> <a class="nav-sublink2" href="/examples/nlp/question_answering/">Question Answering with Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/nlp/t5_hf_summarization/">Abstractive Summarization with Hugging Face Transformers</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/nlp/'>Natural Language Processing</a> / Text Extraction with BERT </div> <div class='k-content'> <h1 id="text-extraction-with-bert">Text Extraction with BERT</h1> <p><strong>Author:</strong> <a href="https://twitter.com/NandanApoorv">Apoorv Nandan</a><br> <strong>Date created:</strong> 2020/05/23<br> <strong>Last modified:</strong> 2020/05/23<br></p> <div class='example_version_banner keras_2'>ⓘ This example uses Keras 2</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/nlp/ipynb/text_extraction_with_bert.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/nlp/text_extraction_with_bert.py"><strong>GitHub source</strong></a></p> <p><strong>Description:</strong> Fine tune pretrained BERT from HuggingFace Transformers on SQuAD.</p> <hr /> <h2 id="introduction">Introduction</h2> <p>This demonstration uses SQuAD (Stanford Question-Answering Dataset). In SQuAD, an input consists of a question, and a paragraph for context. The goal is to find the span of text in the paragraph that answers the question. We evaluate our performance on this data with the "Exact Match" metric, which measures the percentage of predictions that exactly match any one of the ground-truth answers.</p> <p>We fine-tune a BERT model to perform this task as follows:</p> <ol> <li>Feed the context and the question as inputs to BERT.</li> <li>Take two vectors S and T with dimensions equal to that of hidden states in BERT.</li> <li>Compute the probability of each token being the start and end of the answer span. The probability of a token being the start of the answer is given by a dot product between S and the representation of the token in the last layer of BERT, followed by a softmax over all tokens. The probability of a token being the end of the answer is computed similarly with the vector T.</li> <li>Fine-tune BERT and learn S and T along the way.</li> </ol> <p><strong>References:</strong></p> <ul> <li><a href="https://arxiv.org/pdf/1810.04805.pdf">BERT</a></li> <li><a href="https://arxiv.org/abs/1606.05250">SQuAD</a></li> </ul> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span><span class="w"> </span><span class="nn">os</span> <span class="kn">import</span><span class="w"> </span><span class="nn">re</span> <span class="kn">import</span><span class="w"> </span><span class="nn">json</span> <span class="kn">import</span><span class="w"> </span><span class="nn">string</span> <span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span> <span class="kn">import</span><span class="w"> </span><span class="nn">tensorflow</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">tf</span> <span class="kn">from</span><span class="w"> </span><span class="nn">tensorflow</span><span class="w"> </span><span class="kn">import</span> <span class="n">keras</span> <span class="kn">from</span><span class="w"> </span><span class="nn">tensorflow.keras</span><span class="w"> </span><span class="kn">import</span> <span class="n">layers</span> <span class="kn">from</span><span class="w"> </span><span class="nn">tokenizers</span><span class="w"> </span><span class="kn">import</span> <span class="n">BertWordPieceTokenizer</span> <span class="kn">from</span><span class="w"> </span><span class="nn">transformers</span><span class="w"> </span><span class="kn">import</span> <span class="n">BertTokenizer</span><span class="p">,</span> <span class="n">TFBertModel</span><span class="p">,</span> <span class="n">BertConfig</span> <span class="n">max_len</span> <span class="o">=</span> <span class="mi">384</span> <span class="n">configuration</span> <span class="o">=</span> <span class="n">BertConfig</span><span class="p">()</span> <span class="c1"># default parameters and configuration for BERT</span> </code></pre></div> <hr /> <h2 id="setup-bert-tokenizer">Set-up BERT tokenizer</h2> <div class="codehilite"><pre><span></span><code><span class="c1"># Save the slow pretrained tokenizer</span> <span class="n">slow_tokenizer</span> <span class="o">=</span> <span class="n">BertTokenizer</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s2">"bert-base-uncased"</span><span class="p">)</span> <span class="n">save_path</span> <span class="o">=</span> <span class="s2">"bert_base_uncased/"</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">save_path</span><span class="p">):</span> <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">save_path</span><span class="p">)</span> <span class="n">slow_tokenizer</span><span class="o">.</span><span class="n">save_pretrained</span><span class="p">(</span><span class="n">save_path</span><span class="p">)</span> <span class="c1"># Load the fast tokenizer from saved file</span> <span class="n">tokenizer</span> <span class="o">=</span> <span class="n">BertWordPieceTokenizer</span><span class="p">(</span><span class="s2">"bert_base_uncased/vocab.txt"</span><span class="p">,</span> <span class="n">lowercase</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="load-the-data">Load the data</h2> <div class="codehilite"><pre><span></span><code><span class="n">train_data_url</span> <span class="o">=</span> <span class="s2">"https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json"</span> <span class="n">train_path</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">get_file</span><span class="p">(</span><span class="s2">"train.json"</span><span class="p">,</span> <span class="n">train_data_url</span><span class="p">)</span> <span class="n">eval_data_url</span> <span class="o">=</span> <span class="s2">"https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json"</span> <span class="n">eval_path</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">get_file</span><span class="p">(</span><span class="s2">"eval.json"</span><span class="p">,</span> <span class="n">eval_data_url</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="preprocess-the-data">Preprocess the data</h2> <ol> <li>Go through the JSON file and store every record as a <code>SquadExample</code> object.</li> <li>Go through each <code>SquadExample</code> and create <code>x_train, y_train, x_eval, y_eval</code>.</li> </ol> <div class="codehilite"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">SquadExample</span><span class="p">:</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">question</span><span class="p">,</span> <span class="n">context</span><span class="p">,</span> <span class="n">start_char_idx</span><span class="p">,</span> <span class="n">answer_text</span><span class="p">,</span> <span class="n">all_answers</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">question</span> <span class="o">=</span> <span class="n">question</span> <span class="bp">self</span><span class="o">.</span><span class="n">context</span> <span class="o">=</span> <span class="n">context</span> <span class="bp">self</span><span class="o">.</span><span class="n">start_char_idx</span> <span class="o">=</span> <span class="n">start_char_idx</span> <span class="bp">self</span><span class="o">.</span><span class="n">answer_text</span> <span class="o">=</span> <span class="n">answer_text</span> <span class="bp">self</span><span class="o">.</span><span class="n">all_answers</span> <span class="o">=</span> <span class="n">all_answers</span> <span class="bp">self</span><span class="o">.</span><span class="n">skip</span> <span class="o">=</span> <span class="kc">False</span> <span class="k">def</span><span class="w"> </span><span class="nf">preprocess</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="n">context</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">context</span> <span class="n">question</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">question</span> <span class="n">answer_text</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">answer_text</span> <span class="n">start_char_idx</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">start_char_idx</span> <span class="c1"># Clean context, answer and question</span> <span class="n">context</span> <span class="o">=</span> <span class="s2">" "</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">context</span><span class="p">)</span><span class="o">.</span><span class="n">split</span><span class="p">())</span> <span class="n">question</span> <span class="o">=</span> <span class="s2">" "</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">question</span><span class="p">)</span><span class="o">.</span><span class="n">split</span><span class="p">())</span> <span class="n">answer</span> <span class="o">=</span> <span class="s2">" "</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="nb">str</span><span class="p">(</span><span class="n">answer_text</span><span class="p">)</span><span class="o">.</span><span class="n">split</span><span class="p">())</span> <span class="c1"># Find end character index of answer in context</span> <span class="n">end_char_idx</span> <span class="o">=</span> <span class="n">start_char_idx</span> <span class="o">+</span> <span class="nb">len</span><span class="p">(</span><span class="n">answer</span><span class="p">)</span> <span class="k">if</span> <span class="n">end_char_idx</span> <span class="o">>=</span> <span class="nb">len</span><span class="p">(</span><span class="n">context</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">skip</span> <span class="o">=</span> <span class="kc">True</span> <span class="k">return</span> <span class="c1"># Mark the character indexes in context that are in answer</span> <span class="n">is_char_in_ans</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">context</span><span class="p">)</span> <span class="k">for</span> <span class="n">idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">start_char_idx</span><span class="p">,</span> <span class="n">end_char_idx</span><span class="p">):</span> <span class="n">is_char_in_ans</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># Tokenize context</span> <span class="n">tokenized_context</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">context</span><span class="p">)</span> <span class="c1"># Find tokens that were created from answer characters</span> <span class="n">ans_token_idx</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="n">end</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">tokenized_context</span><span class="o">.</span><span class="n">offsets</span><span class="p">):</span> <span class="k">if</span> <span class="nb">sum</span><span class="p">(</span><span class="n">is_char_in_ans</span><span class="p">[</span><span class="n">start</span><span class="p">:</span><span class="n">end</span><span class="p">])</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span> <span class="n">ans_token_idx</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">idx</span><span class="p">)</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">ans_token_idx</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">skip</span> <span class="o">=</span> <span class="kc">True</span> <span class="k">return</span> <span class="c1"># Find start and end token index for tokens from answer</span> <span class="n">start_token_idx</span> <span class="o">=</span> <span class="n">ans_token_idx</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="n">end_token_idx</span> <span class="o">=</span> <span class="n">ans_token_idx</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="c1"># Tokenize question</span> <span class="n">tokenized_question</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">question</span><span class="p">)</span> <span class="c1"># Create inputs</span> <span class="n">input_ids</span> <span class="o">=</span> <span class="n">tokenized_context</span><span class="o">.</span><span class="n">ids</span> <span class="o">+</span> <span class="n">tokenized_question</span><span class="o">.</span><span class="n">ids</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="n">token_type_ids</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">tokenized_context</span><span class="o">.</span><span class="n">ids</span><span class="p">)</span> <span class="o">+</span> <span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span> <span class="n">tokenized_question</span><span class="o">.</span><span class="n">ids</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="p">)</span> <span class="n">attention_mask</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span> <span class="c1"># Pad and create attention masks.</span> <span class="c1"># Skip if truncation is needed</span> <span class="n">padding_length</span> <span class="o">=</span> <span class="n">max_len</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">input_ids</span><span class="p">)</span> <span class="k">if</span> <span class="n">padding_length</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span> <span class="c1"># pad</span> <span class="n">input_ids</span> <span class="o">=</span> <span class="n">input_ids</span> <span class="o">+</span> <span class="p">([</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">padding_length</span><span class="p">)</span> <span class="n">attention_mask</span> <span class="o">=</span> <span class="n">attention_mask</span> <span class="o">+</span> <span class="p">([</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">padding_length</span><span class="p">)</span> <span class="n">token_type_ids</span> <span class="o">=</span> <span class="n">token_type_ids</span> <span class="o">+</span> <span class="p">([</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">padding_length</span><span class="p">)</span> <span class="k">elif</span> <span class="n">padding_length</span> <span class="o"><</span> <span class="mi">0</span><span class="p">:</span> <span class="c1"># skip</span> <span class="bp">self</span><span class="o">.</span><span class="n">skip</span> <span class="o">=</span> <span class="kc">True</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_ids</span> <span class="o">=</span> <span class="n">input_ids</span> <span class="bp">self</span><span class="o">.</span><span class="n">token_type_ids</span> <span class="o">=</span> <span class="n">token_type_ids</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention_mask</span> <span class="o">=</span> <span class="n">attention_mask</span> <span class="bp">self</span><span class="o">.</span><span class="n">start_token_idx</span> <span class="o">=</span> <span class="n">start_token_idx</span> <span class="bp">self</span><span class="o">.</span><span class="n">end_token_idx</span> <span class="o">=</span> <span class="n">end_token_idx</span> <span class="bp">self</span><span class="o">.</span><span class="n">context_token_to_char</span> <span class="o">=</span> <span class="n">tokenized_context</span><span class="o">.</span><span class="n">offsets</span> <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">train_path</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span> <span class="n">raw_train_data</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">f</span><span class="p">)</span> <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">eval_path</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span> <span class="n">raw_eval_data</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">f</span><span class="p">)</span> <span class="k">def</span><span class="w"> </span><span class="nf">create_squad_examples</span><span class="p">(</span><span class="n">raw_data</span><span class="p">):</span> <span class="n">squad_examples</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">item</span> <span class="ow">in</span> <span class="n">raw_data</span><span class="p">[</span><span class="s2">"data"</span><span class="p">]:</span> <span class="k">for</span> <span class="n">para</span> <span class="ow">in</span> <span class="n">item</span><span class="p">[</span><span class="s2">"paragraphs"</span><span class="p">]:</span> <span class="n">context</span> <span class="o">=</span> <span class="n">para</span><span class="p">[</span><span class="s2">"context"</span><span class="p">]</span> <span class="k">for</span> <span class="n">qa</span> <span class="ow">in</span> <span class="n">para</span><span class="p">[</span><span class="s2">"qas"</span><span class="p">]:</span> <span class="n">question</span> <span class="o">=</span> <span class="n">qa</span><span class="p">[</span><span class="s2">"question"</span><span class="p">]</span> <span class="n">answer_text</span> <span class="o">=</span> <span class="n">qa</span><span class="p">[</span><span class="s2">"answers"</span><span class="p">][</span><span class="mi">0</span><span class="p">][</span><span class="s2">"text"</span><span class="p">]</span> <span class="n">all_answers</span> <span class="o">=</span> <span class="p">[</span><span class="n">_</span><span class="p">[</span><span class="s2">"text"</span><span class="p">]</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="n">qa</span><span class="p">[</span><span class="s2">"answers"</span><span class="p">]]</span> <span class="n">start_char_idx</span> <span class="o">=</span> <span class="n">qa</span><span class="p">[</span><span class="s2">"answers"</span><span class="p">][</span><span class="mi">0</span><span class="p">][</span><span class="s2">"answer_start"</span><span class="p">]</span> <span class="n">squad_eg</span> <span class="o">=</span> <span class="n">SquadExample</span><span class="p">(</span> <span class="n">question</span><span class="p">,</span> <span class="n">context</span><span class="p">,</span> <span class="n">start_char_idx</span><span class="p">,</span> <span class="n">answer_text</span><span class="p">,</span> <span class="n">all_answers</span> <span class="p">)</span> <span class="n">squad_eg</span><span class="o">.</span><span class="n">preprocess</span><span class="p">()</span> <span class="n">squad_examples</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">squad_eg</span><span class="p">)</span> <span class="k">return</span> <span class="n">squad_examples</span> <span class="k">def</span><span class="w"> </span><span class="nf">create_inputs_targets</span><span class="p">(</span><span class="n">squad_examples</span><span class="p">):</span> <span class="n">dataset_dict</span> <span class="o">=</span> <span class="p">{</span> <span class="s2">"input_ids"</span><span class="p">:</span> <span class="p">[],</span> <span class="s2">"token_type_ids"</span><span class="p">:</span> <span class="p">[],</span> <span class="s2">"attention_mask"</span><span class="p">:</span> <span class="p">[],</span> <span class="s2">"start_token_idx"</span><span class="p">:</span> <span class="p">[],</span> <span class="s2">"end_token_idx"</span><span class="p">:</span> <span class="p">[],</span> <span class="p">}</span> <span class="k">for</span> <span class="n">item</span> <span class="ow">in</span> <span class="n">squad_examples</span><span class="p">:</span> <span class="k">if</span> <span class="n">item</span><span class="o">.</span><span class="n">skip</span> <span class="o">==</span> <span class="kc">False</span><span class="p">:</span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">dataset_dict</span><span class="p">:</span> <span class="n">dataset_dict</span><span class="p">[</span><span class="n">key</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="nb">getattr</span><span class="p">(</span><span class="n">item</span><span class="p">,</span> <span class="n">key</span><span class="p">))</span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">dataset_dict</span><span class="p">:</span> <span class="n">dataset_dict</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">dataset_dict</span><span class="p">[</span><span class="n">key</span><span class="p">])</span> <span class="n">x</span> <span class="o">=</span> <span class="p">[</span> <span class="n">dataset_dict</span><span class="p">[</span><span class="s2">"input_ids"</span><span class="p">],</span> <span class="n">dataset_dict</span><span class="p">[</span><span class="s2">"token_type_ids"</span><span class="p">],</span> <span class="n">dataset_dict</span><span class="p">[</span><span class="s2">"attention_mask"</span><span class="p">],</span> <span class="p">]</span> <span class="n">y</span> <span class="o">=</span> <span class="p">[</span><span class="n">dataset_dict</span><span class="p">[</span><span class="s2">"start_token_idx"</span><span class="p">],</span> <span class="n">dataset_dict</span><span class="p">[</span><span class="s2">"end_token_idx"</span><span class="p">]]</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="n">train_squad_examples</span> <span class="o">=</span> <span class="n">create_squad_examples</span><span class="p">(</span><span class="n">raw_train_data</span><span class="p">)</span> <span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span> <span class="o">=</span> <span class="n">create_inputs_targets</span><span class="p">(</span><span class="n">train_squad_examples</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">train_squad_examples</span><span class="p">)</span><span class="si">}</span><span class="s2"> training points created."</span><span class="p">)</span> <span class="n">eval_squad_examples</span> <span class="o">=</span> <span class="n">create_squad_examples</span><span class="p">(</span><span class="n">raw_eval_data</span><span class="p">)</span> <span class="n">x_eval</span><span class="p">,</span> <span class="n">y_eval</span> <span class="o">=</span> <span class="n">create_inputs_targets</span><span class="p">(</span><span class="n">eval_squad_examples</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">eval_squad_examples</span><span class="p">)</span><span class="si">}</span><span class="s2"> evaluation points created."</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>87599 training points created. 10570 evaluation points created. </code></pre></div> </div> <p>Create the Question-Answering Model using BERT and Functional API</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">create_model</span><span class="p">():</span> <span class="c1">## BERT encoder</span> <span class="n">encoder</span> <span class="o">=</span> <span class="n">TFBertModel</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s2">"bert-base-uncased"</span><span class="p">)</span> <span class="c1">## QA Model</span> <span class="n">input_ids</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">max_len</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> <span class="n">token_type_ids</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">max_len</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> <span class="n">attention_mask</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">max_len</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> <span class="n">embedding</span> <span class="o">=</span> <span class="n">encoder</span><span class="p">(</span> <span class="n">input_ids</span><span class="p">,</span> <span class="n">token_type_ids</span><span class="o">=</span><span class="n">token_type_ids</span><span class="p">,</span> <span class="n">attention_mask</span><span class="o">=</span><span class="n">attention_mask</span> <span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="n">start_logits</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"start_logit"</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)(</span><span class="n">embedding</span><span class="p">)</span> <span class="n">start_logits</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Flatten</span><span class="p">()(</span><span class="n">start_logits</span><span class="p">)</span> <span class="n">end_logits</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"end_logit"</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)(</span><span class="n">embedding</span><span class="p">)</span> <span class="n">end_logits</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Flatten</span><span class="p">()(</span><span class="n">end_logits</span><span class="p">)</span> <span class="n">start_probs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">activations</span><span class="o">.</span><span class="n">softmax</span><span class="p">)(</span><span class="n">start_logits</span><span class="p">)</span> <span class="n">end_probs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">activations</span><span class="o">.</span><span class="n">softmax</span><span class="p">)(</span><span class="n">end_logits</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span> <span class="n">inputs</span><span class="o">=</span><span class="p">[</span><span class="n">input_ids</span><span class="p">,</span> <span class="n">token_type_ids</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">],</span> <span class="n">outputs</span><span class="o">=</span><span class="p">[</span><span class="n">start_probs</span><span class="p">,</span> <span class="n">end_probs</span><span class="p">],</span> <span class="p">)</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">lr</span><span class="o">=</span><span class="mf">5e-5</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="p">[</span><span class="n">loss</span><span class="p">,</span> <span class="n">loss</span><span class="p">])</span> <span class="k">return</span> <span class="n">model</span> </code></pre></div> <p>This code should preferably be run on Google Colab TPU runtime. With Colab TPUs, each epoch will take 5-6 minutes.</p> <div class="codehilite"><pre><span></span><code><span class="n">use_tpu</span> <span class="o">=</span> <span class="kc">True</span> <span class="k">if</span> <span class="n">use_tpu</span><span class="p">:</span> <span class="c1"># Create distribution strategy</span> <span class="n">tpu</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">distribute</span><span class="o">.</span><span class="n">cluster_resolver</span><span class="o">.</span><span class="n">TPUClusterResolver</span><span class="o">.</span><span class="n">connect</span><span class="p">()</span> <span class="n">strategy</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">distribute</span><span class="o">.</span><span class="n">TPUStrategy</span><span class="p">(</span><span class="n">tpu</span><span class="p">)</span> <span class="c1"># Create model</span> <span class="k">with</span> <span class="n">strategy</span><span class="o">.</span><span class="n">scope</span><span class="p">():</span> <span class="n">model</span> <span class="o">=</span> <span class="n">create_model</span><span class="p">()</span> <span class="k">else</span><span class="p">:</span> <span class="n">model</span> <span class="o">=</span> <span class="n">create_model</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>INFO:absl:Entering into master device scope: /job:worker/replica:0/task:0/device:CPU:0 INFO:tensorflow:Initializing the TPU system: grpc://10.48.159.170:8470 INFO:tensorflow:Clearing out eager caches INFO:tensorflow:Finished initializing TPU system. INFO:tensorflow:Found TPU system: INFO:tensorflow:*** Num TPU Cores: 8 INFO:tensorflow:*** Num TPU Workers: 1 INFO:tensorflow:*** Num TPU Cores Per Worker: 8 Model: "model" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== input_1 (InputLayer) [(None, 384)] 0 __________________________________________________________________________________________________ input_3 (InputLayer) [(None, 384)] 0 __________________________________________________________________________________________________ input_2 (InputLayer) [(None, 384)] 0 __________________________________________________________________________________________________ tf_bert_model (TFBertModel) ((None, 384, 768), ( 109482240 input_1[0][0] __________________________________________________________________________________________________ start_logit (Dense) (None, 384, 1) 768 tf_bert_model[0][0] __________________________________________________________________________________________________ end_logit (Dense) (None, 384, 1) 768 tf_bert_model[0][0] __________________________________________________________________________________________________ flatten (Flatten) (None, 384) 0 start_logit[0][0] __________________________________________________________________________________________________ flatten_1 (Flatten) (None, 384) 0 end_logit[0][0] __________________________________________________________________________________________________ activation_7 (Activation) (None, 384) 0 flatten[0][0] __________________________________________________________________________________________________ activation_8 (Activation) (None, 384) 0 flatten_1[0][0] ================================================================================================== Total params: 109,483,776 Trainable params: 109,483,776 Non-trainable params: 0 __________________________________________________________________________________________________ </code></pre></div> </div> <hr /> <h2 id="create-evaluation-callback">Create evaluation Callback</h2> <p>This callback will compute the exact match score using the validation data after every epoch.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">normalize_text</span><span class="p">(</span><span class="n">text</span><span class="p">):</span> <span class="n">text</span> <span class="o">=</span> <span class="n">text</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span> <span class="c1"># Remove punctuations</span> <span class="n">exclude</span> <span class="o">=</span> <span class="nb">set</span><span class="p">(</span><span class="n">string</span><span class="o">.</span><span class="n">punctuation</span><span class="p">)</span> <span class="n">text</span> <span class="o">=</span> <span class="s2">""</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">ch</span> <span class="k">for</span> <span class="n">ch</span> <span class="ow">in</span> <span class="n">text</span> <span class="k">if</span> <span class="n">ch</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">exclude</span><span class="p">)</span> <span class="c1"># Remove articles</span> <span class="n">regex</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="sa">r</span><span class="s2">"\b(a|an|the)\b"</span><span class="p">,</span> <span class="n">re</span><span class="o">.</span><span class="n">UNICODE</span><span class="p">)</span> <span class="n">text</span> <span class="o">=</span> <span class="n">re</span><span class="o">.</span><span class="n">sub</span><span class="p">(</span><span class="n">regex</span><span class="p">,</span> <span class="s2">" "</span><span class="p">,</span> <span class="n">text</span><span class="p">)</span> <span class="c1"># Remove extra white space</span> <span class="n">text</span> <span class="o">=</span> <span class="s2">" "</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">text</span><span class="o">.</span><span class="n">split</span><span class="p">())</span> <span class="k">return</span> <span class="n">text</span> <span class="k">class</span><span class="w"> </span><span class="nc">ExactMatch</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">Callback</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""</span> <span class="sd"> Each `SquadExample` object contains the character level offsets for each token</span> <span class="sd"> in its input paragraph. We use them to get back the span of text corresponding</span> <span class="sd"> to the tokens between our predicted start and end tokens.</span> <span class="sd"> All the ground-truth answers are also present in each `SquadExample` object.</span> <span class="sd"> We calculate the percentage of data points where the span of text obtained</span> <span class="sd"> from model predictions matches one of the ground-truth answers.</span> <span class="sd"> """</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x_eval</span><span class="p">,</span> <span class="n">y_eval</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">x_eval</span> <span class="o">=</span> <span class="n">x_eval</span> <span class="bp">self</span><span class="o">.</span><span class="n">y_eval</span> <span class="o">=</span> <span class="n">y_eval</span> <span class="k">def</span><span class="w"> </span><span class="nf">on_epoch_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">pred_start</span><span class="p">,</span> <span class="n">pred_end</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">x_eval</span><span class="p">)</span> <span class="n">count</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">eval_examples_no_skip</span> <span class="o">=</span> <span class="p">[</span><span class="n">_</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="n">eval_squad_examples</span> <span class="k">if</span> <span class="n">_</span><span class="o">.</span><span class="n">skip</span> <span class="o">==</span> <span class="kc">False</span><span class="p">]</span> <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="p">(</span><span class="n">start</span><span class="p">,</span> <span class="n">end</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">pred_start</span><span class="p">,</span> <span class="n">pred_end</span><span class="p">)):</span> <span class="n">squad_eg</span> <span class="o">=</span> <span class="n">eval_examples_no_skip</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="n">offsets</span> <span class="o">=</span> <span class="n">squad_eg</span><span class="o">.</span><span class="n">context_token_to_char</span> <span class="n">start</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">start</span><span class="p">)</span> <span class="n">end</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">end</span><span class="p">)</span> <span class="k">if</span> <span class="n">start</span> <span class="o">>=</span> <span class="nb">len</span><span class="p">(</span><span class="n">offsets</span><span class="p">):</span> <span class="k">continue</span> <span class="n">pred_char_start</span> <span class="o">=</span> <span class="n">offsets</span><span class="p">[</span><span class="n">start</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="k">if</span> <span class="n">end</span> <span class="o"><</span> <span class="nb">len</span><span class="p">(</span><span class="n">offsets</span><span class="p">):</span> <span class="n">pred_char_end</span> <span class="o">=</span> <span class="n">offsets</span><span class="p">[</span><span class="n">end</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span> <span class="n">pred_ans</span> <span class="o">=</span> <span class="n">squad_eg</span><span class="o">.</span><span class="n">context</span><span class="p">[</span><span class="n">pred_char_start</span><span class="p">:</span><span class="n">pred_char_end</span><span class="p">]</span> <span class="k">else</span><span class="p">:</span> <span class="n">pred_ans</span> <span class="o">=</span> <span class="n">squad_eg</span><span class="o">.</span><span class="n">context</span><span class="p">[</span><span class="n">pred_char_start</span><span class="p">:]</span> <span class="n">normalized_pred_ans</span> <span class="o">=</span> <span class="n">normalize_text</span><span class="p">(</span><span class="n">pred_ans</span><span class="p">)</span> <span class="n">normalized_true_ans</span> <span class="o">=</span> <span class="p">[</span><span class="n">normalize_text</span><span class="p">(</span><span class="n">_</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="n">squad_eg</span><span class="o">.</span><span class="n">all_answers</span><span class="p">]</span> <span class="k">if</span> <span class="n">normalized_pred_ans</span> <span class="ow">in</span> <span class="n">normalized_true_ans</span><span class="p">:</span> <span class="n">count</span> <span class="o">+=</span> <span class="mi">1</span> <span class="n">acc</span> <span class="o">=</span> <span class="n">count</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">y_eval</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="se">\n</span><span class="s2">epoch=</span><span class="si">{</span><span class="n">epoch</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s2">, exact match score=</span><span class="si">{</span><span class="n">acc</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="train-and-evaluate">Train and Evaluate</h2> <div class="codehilite"><pre><span></span><code><span class="n">exact_match_callback</span> <span class="o">=</span> <span class="n">ExactMatch</span><span class="p">(</span><span class="n">x_eval</span><span class="p">,</span> <span class="n">y_eval</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="c1"># For demonstration, 3 epochs are recommended</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">exact_match_callback</span><span class="p">],</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>epoch=1, exact match score=0.78 1346/1346 - 350s - activation_7_loss: 1.3488 - loss: 2.5905 - activation_8_loss: 1.2417 <tensorflow.python.keras.callbacks.History at 0x7fc78b4458d0> </code></pre></div> </div> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#text-extraction-with-bert'>Text Extraction with BERT</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup-bert-tokenizer'>Set-up BERT tokenizer</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#load-the-data'>Load the data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#preprocess-the-data'>Preprocess the data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#create-evaluation-callback'>Create evaluation Callback</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-and-evaluate'>Train and Evaluate</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>