CINXE.COM

Pretraining BERT with Hugging Face Transformers

<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/nlp/pretraining_BERT/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Pretraining BERT with Hugging Face Transformers"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Pretraining BERT with Hugging Face Transformers"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Pretraining BERT with Hugging Face Transformers</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink active" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_from_scratch/">Text classification from scratch</a> <a class="nav-sublink2" href="/examples/nlp/active_learning_review_classification/">Review Classification using Active Learning</a> <a class="nav-sublink2" href="/examples/nlp/fnet_classification_with_keras_hub/">Text Classification using FNet</a> <a class="nav-sublink2" href="/examples/nlp/multi_label_classification/">Large-scale multi-label text classification</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_with_transformer/">Text classification with Transformer</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_with_switch_transformer/">Text classification with Switch Transformer</a> <a class="nav-sublink2" href="/examples/nlp/tweet-classification-using-tfdf/">Text classification using Decision Forests and pretrained embeddings</a> <a class="nav-sublink2" href="/examples/nlp/pretrained_word_embeddings/">Using pre-trained word embeddings</a> <a class="nav-sublink2" href="/examples/nlp/bidirectional_lstm_imdb/">Bidirectional LSTM on IMDB</a> <a class="nav-sublink2" href="/examples/nlp/data_parallel_training_with_keras_hub/">Data Parallel Training with KerasHub and tf.distribute</a> <a class="nav-sublink2" href="/examples/nlp/neural_machine_translation_with_keras_hub/">English-to-Spanish translation with KerasHub</a> <a class="nav-sublink2" href="/examples/nlp/neural_machine_translation_with_transformer/">English-to-Spanish translation with a sequence-to-sequence Transformer</a> <a class="nav-sublink2" href="/examples/nlp/lstm_seq2seq/">Character-level recurrent sequence-to-sequence model</a> <a class="nav-sublink2" href="/examples/nlp/multimodal_entailment/">Multimodal entailment</a> <a class="nav-sublink2" href="/examples/nlp/ner_transformers/">Named Entity Recognition using Transformers</a> <a class="nav-sublink2" href="/examples/nlp/text_extraction_with_bert/">Text Extraction with BERT</a> <a class="nav-sublink2" href="/examples/nlp/addition_rnn/">Sequence to sequence learning for performing number addition</a> <a class="nav-sublink2" href="/examples/nlp/semantic_similarity_with_keras_hub/">Semantic Similarity with KerasHub</a> <a class="nav-sublink2" href="/examples/nlp/semantic_similarity_with_bert/">Semantic Similarity with BERT</a> <a class="nav-sublink2" href="/examples/nlp/sentence_embeddings_with_sbert/">Sentence embeddings using Siamese RoBERTa-networks</a> <a class="nav-sublink2" href="/examples/nlp/masked_language_modeling/">End-to-end Masked Language Modeling with BERT</a> <a class="nav-sublink2" href="/examples/nlp/abstractive_summarization_with_bart/">Abstractive Text Summarization with BART</a> <a class="nav-sublink2 active" href="/examples/nlp/pretraining_BERT/">Pretraining BERT with Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/">Parameter-efficient fine-tuning of GPT-2 with LoRA</a> <a class="nav-sublink2" href="/examples/nlp/multiple_choice_task_with_transfer_learning/">MultipleChoice Task with Transfer Learning</a> <a class="nav-sublink2" href="/examples/nlp/question_answering/">Question Answering with Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/nlp/t5_hf_summarization/">Abstractive Summarization with Hugging Face Transformers</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/nlp/'>Natural Language Processing</a> / Pretraining BERT with Hugging Face Transformers </div> <div class='k-content'> <h1 id="pretraining-bert-with-hugging-face-transformers">Pretraining BERT with Hugging Face Transformers</h1> <p><strong>Author:</strong> Sreyan Ghosh<br> <strong>Date created:</strong> 2022/07/01<br> <strong>Last modified:</strong> 2022/08/27<br> <strong>Description:</strong> Pretraining BERT using Hugging Face Transformers on NSP and MLM.</p> <div class='example_version_banner keras_2'>ⓘ This example uses Keras 2</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/pretraining/ipynb/pretraining_BERT.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/pretraining/pretraining_BERT.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <h3 id="bert-bidirectional-encoder-representations-from-transformers">BERT (Bidirectional Encoder Representations from Transformers)</h3> <p>In the field of computer vision, researchers have repeatedly shown the value of transfer learning — pretraining a neural network model on a known task/dataset, for instance ImageNet classification, and then performing fine-tuning — using the trained neural network as the basis of a new specific-purpose model. In recent years, researchers have shown that a similar technique can be useful in many natural language tasks.</p> <p>BERT makes use of Transformer, an attention mechanism that learns contextual relations between words (or subwords) in a text. In its vanilla form, Transformer includes two separate mechanisms — an encoder that reads the text input and a decoder that produces a prediction for the task. Since BERT’s goal is to generate a language model, only the encoder mechanism is necessary. The detailed workings of Transformer are described in a paper by Google.</p> <p>As opposed to directional models, which read the text input sequentially (left-to-right or right-to-left), the Transformer encoder reads the entire sequence of words at once. Therefore it is considered bidirectional, though it would be more accurate to say that it’s non-directional. This characteristic allows the model to learn the context of a word based on all of its surroundings (left and right of the word).</p> <p>When training language models, a challenge is defining a prediction goal. Many models predict the next word in a sequence (e.g. <code>"The child came home from _"</code>), a directional approach which inherently limits context learning. To overcome this challenge, BERT uses two training strategies:</p> <h3 id="masked-language-modeling-mlm">Masked Language Modeling (MLM)</h3> <p>Before feeding word sequences into BERT, 15% of the words in each sequence are replaced with a <code>[MASK]</code> token. The model then attempts to predict the original value of the masked words, based on the context provided by the other, non-masked, words in the sequence.</p> <h3 id="next-sentence-prediction-nsp">Next Sentence Prediction (NSP)</h3> <p>In the BERT training process, the model receives pairs of sentences as input and learns to predict if the second sentence in the pair is the subsequent sentence in the original document. During training, 50% of the inputs are a pair in which the second sentence is the subsequent sentence in the original document, while in the other 50% a random sentence from the corpus is chosen as the second sentence. The assumption is that the random sentence will represent a disconnect from the first sentence.</p> <p>Though Google provides a pretrained BERT checkpoint for English, you may often need to either pretrain the model from scratch for a different language, or do a continued-pretraining to fit the model to a new domain. In this notebook, we pretrain BERT from scratch optimizing both MLM and NSP objectves using 🤗 Transformers on the <code>WikiText</code> English dataset loaded from 🤗 Datasets.</p> <hr /> <h2 id="setup">Setup</h2> <h3 id="installing-the-requirements">Installing the requirements</h3> <div class="codehilite"><pre><span></span><code><span class="n">pip</span> <span class="n">install</span> <span class="n">git</span><span class="o">+</span><span class="n">https</span><span class="p">:</span><span class="o">//</span><span class="n">github</span><span class="o">.</span><span class="n">com</span><span class="o">/</span><span class="n">huggingface</span><span class="o">/</span><span class="n">transformers</span><span class="o">.</span><span class="n">git</span> <span class="n">pip</span> <span class="n">install</span> <span class="n">datasets</span> <span class="n">pip</span> <span class="n">install</span> <span class="n">huggingface</span><span class="o">-</span><span class="n">hub</span> <span class="n">pip</span> <span class="n">install</span> <span class="n">nltk</span> </code></pre></div> <h3 id="importing-the-necessary-libraries">Importing the necessary libraries</h3> <div class="codehilite"><pre><span></span><code><span class="kn">import</span><span class="w"> </span><span class="nn">nltk</span> <span class="kn">import</span><span class="w"> </span><span class="nn">random</span> <span class="kn">import</span><span class="w"> </span><span class="nn">logging</span> <span class="kn">import</span><span class="w"> </span><span class="nn">tensorflow</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">tf</span> <span class="kn">from</span><span class="w"> </span><span class="nn">tensorflow</span><span class="w"> </span><span class="kn">import</span> <span class="n">keras</span> <span class="n">nltk</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="s2">&quot;punkt&quot;</span><span class="p">)</span> <span class="c1"># Only log error messages</span> <span class="n">tf</span><span class="o">.</span><span class="n">get_logger</span><span class="p">()</span><span class="o">.</span><span class="n">setLevel</span><span class="p">(</span><span class="n">logging</span><span class="o">.</span><span class="n">ERROR</span><span class="p">)</span> <span class="c1"># Set random seed</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">set_random_seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>[nltk_data] Downloading package punkt to /speech/sreyan/nltk_data... [nltk_data] Package punkt is already up-to-date! </code></pre></div> </div> <h3 id="define-certain-variables">Define certain variables</h3> <div class="codehilite"><pre><span></span><code><span class="n">TOKENIZER_BATCH_SIZE</span> <span class="o">=</span> <span class="mi">256</span> <span class="c1"># Batch-size to train the tokenizer on</span> <span class="n">TOKENIZER_VOCABULARY</span> <span class="o">=</span> <span class="mi">25000</span> <span class="c1"># Total number of unique subwords the tokenizer can have</span> <span class="n">BLOCK_SIZE</span> <span class="o">=</span> <span class="mi">128</span> <span class="c1"># Maximum number of tokens in an input sample</span> <span class="n">NSP_PROB</span> <span class="o">=</span> <span class="mf">0.50</span> <span class="c1"># Probability that the next sentence is the actual next sentence in NSP</span> <span class="n">SHORT_SEQ_PROB</span> <span class="o">=</span> <span class="mf">0.1</span> <span class="c1"># Probability of generating shorter sequences to minimize the mismatch between pretraining and fine-tuning.</span> <span class="n">MAX_LENGTH</span> <span class="o">=</span> <span class="mi">512</span> <span class="c1"># Maximum number of tokens in an input sample after padding</span> <span class="n">MLM_PROB</span> <span class="o">=</span> <span class="mf">0.2</span> <span class="c1"># Probability with which tokens are masked in MLM</span> <span class="n">TRAIN_BATCH_SIZE</span> <span class="o">=</span> <span class="mi">2</span> <span class="c1"># Batch-size for pretraining the model on</span> <span class="n">MAX_EPOCHS</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># Maximum number of epochs to train the model for</span> <span class="n">LEARNING_RATE</span> <span class="o">=</span> <span class="mf">1e-4</span> <span class="c1"># Learning rate for training the model</span> <span class="n">MODEL_CHECKPOINT</span> <span class="o">=</span> <span class="s2">&quot;bert-base-cased&quot;</span> <span class="c1"># Name of pretrained model from 🤗 Model Hub</span> </code></pre></div> <hr /> <h2 id="load-the-wikitext-dataset">Load the WikiText dataset</h2> <p>We now download the <code>WikiText</code> language modeling dataset. It is a collection of over 100 million tokens extracted from the set of verified "Good" and "Featured" articles on Wikipedia.</p> <p>We load the dataset from <a href="https://github.com/huggingface/datasets">🤗 Datasets</a>. For the purpose of demonstration in this notebook, we work with only the <code>train</code> split of the dataset. This can be easily done with the <code>load_dataset</code> function.</p> <div class="codehilite"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">datasets</span><span class="w"> </span><span class="kn">import</span> <span class="n">load_dataset</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">load_dataset</span><span class="p">(</span><span class="s2">&quot;wikitext&quot;</span><span class="p">,</span> <span class="s2">&quot;wikitext-2-raw-v1&quot;</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Downloading and preparing dataset wikitext/wikitext-2-raw-v1 (download: 4.50 MiB, generated: 12.90 MiB, post-processed: Unknown size, total: 17.40 MiB) to /speech/sreyan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126... Downloading data: 0%| | 0.00/4.72M [00:00&lt;?, ?B/s] Generating test split: 0%| | 0/4358 [00:00&lt;?, ? examples/s] Generating train split: 0%| | 0/36718 [00:00&lt;?, ? examples/s] Generating validation split: 0%| | 0/3760 [00:00&lt;?, ? examples/s] Dataset wikitext downloaded and prepared to /speech/sreyan/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126. Subsequent calls will reuse this data. 0%| | 0/3 [00:00&lt;?, ?it/s] </code></pre></div> </div> <p>The dataset just has one column which is the raw text, and this is all we need for pretraining BERT!</p> <div class="codehilite"><pre><span></span><code><span class="nb">print</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>DatasetDict({ test: Dataset({ features: [&#39;text&#39;], num_rows: 4358 }) train: Dataset({ features: [&#39;text&#39;], num_rows: 36718 }) validation: Dataset({ features: [&#39;text&#39;], num_rows: 3760 }) }) </code></pre></div> </div> <hr /> <h2 id="training-a-new-tokenizer">Training a new Tokenizer</h2> <p>First we train our own tokenizer from scratch on our corpus, so that can we can use it to train our language model from scratch.</p> <p>But why would you need to train a tokenizer? That's because Transformer models very often use subword tokenization algorithms, and they need to be trained to identify the parts of words that are often present in the corpus you are using.</p> <p>The 🤗 Transformers <code>Tokenizer</code> (as the name indicates) will tokenize the inputs (including converting the tokens to their corresponding IDs in the pretrained vocabulary) and put it in a format the model expects, as well as generate the other inputs that model requires.</p> <p>First we make a list of all the raw documents from the <code>WikiText</code> corpus:</p> <div class="codehilite"><pre><span></span><code><span class="n">all_texts</span> <span class="o">=</span> <span class="p">[</span> <span class="n">doc</span> <span class="k">for</span> <span class="n">doc</span> <span class="ow">in</span> <span class="n">dataset</span><span class="p">[</span><span class="s2">&quot;train&quot;</span><span class="p">][</span><span class="s2">&quot;text&quot;</span><span class="p">]</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">doc</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">doc</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">&quot; =&quot;</span><span class="p">)</span> <span class="p">]</span> </code></pre></div> <p>Next we make a <code>batch_iterator</code> function that will aid us to train our tokenizer.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">batch_iterator</span><span class="p">():</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">all_texts</span><span class="p">),</span> <span class="n">TOKENIZER_BATCH_SIZE</span><span class="p">):</span> <span class="k">yield</span> <span class="n">all_texts</span><span class="p">[</span><span class="n">i</span> <span class="p">:</span> <span class="n">i</span> <span class="o">+</span> <span class="n">TOKENIZER_BATCH_SIZE</span><span class="p">]</span> </code></pre></div> <p>In this notebook, we train a tokenizer with the exact same algorithms and parameters as an existing one. For instance, we train a new version of the <code>BERT-CASED</code> tokenzier on <code>Wikitext-2</code> using the same tokenization algorithm.</p> <p>First we need to load the tokenizer we want to use as a model:</p> <div class="codehilite"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">transformers</span><span class="w"> </span><span class="kn">import</span> <span class="n">AutoTokenizer</span> <span class="n">tokenizer</span> <span class="o">=</span> <span class="n">AutoTokenizer</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="n">MODEL_CHECKPOINT</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`. Moving 52 files to the new cache system 0%| | 0/52 [00:00&lt;?, ?it/s] vocab_file vocab.txt tokenizer_file tokenizer.json added_tokens_file added_tokens.json special_tokens_map_file special_tokens_map.json tokenizer_config_file tokenizer_config.json </code></pre></div> </div> <p>Now we train our tokenizer using the entire <code>train</code> split of the <code>Wikitext-2</code> dataset.</p> <div class="codehilite"><pre><span></span><code><span class="n">tokenizer</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">train_new_from_iterator</span><span class="p">(</span> <span class="n">batch_iterator</span><span class="p">(),</span> <span class="n">vocab_size</span><span class="o">=</span><span class="n">TOKENIZER_VOCABULARY</span> <span class="p">)</span> </code></pre></div> <p>So now we our done training our new tokenizer! Next we move on to the data pre-processing steps.</p> <hr /> <h2 id="data-preprocessing">Data Pre-processing</h2> <p>For the sake of demonstrating the workflow, in this notebook we only take small subsets of the entire WikiText <code>train</code> and <code>test</code> splits.</p> <div class="codehilite"><pre><span></span><code><span class="n">dataset</span><span class="p">[</span><span class="s2">&quot;train&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">dataset</span><span class="p">[</span><span class="s2">&quot;train&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">select</span><span class="p">([</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1000</span><span class="p">)])</span> <span class="n">dataset</span><span class="p">[</span><span class="s2">&quot;validation&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">dataset</span><span class="p">[</span><span class="s2">&quot;validation&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">select</span><span class="p">([</span><span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1000</span><span class="p">)])</span> </code></pre></div> <p>Before we can feed those texts to our model, we need to pre-process them and get them ready for the task. As mentioned earlier, the BERT pretraining task includes two tasks in total, the <code>NSP</code> task and the <code>MLM</code> task. 🤗 Transformers have an easy to implement <code>collator</code> called the <code>DataCollatorForLanguageModeling</code>. However, we need to get the data ready for <code>NSP</code> manually.</p> <p>Next we write a simple function called the <code>prepare_train_features</code> that helps us in the pre-processing and is compatible with 🤗 Datasets. To summarize, our pre-processing function should:</p> <ul> <li>Get the dataset ready for the NSP task by creating pairs of sentences (A,B), where B either actually follows A, or B is randomly sampled from somewhere else in the corpus. It should also generate a corresponding label for each pair, which is 1 if B actually follows A and 0 if not.</li> <li>Tokenize the text dataset into it's corresponding token ids that will be used for embedding look-up in BERT</li> <li>Create additional inputs for the model like <code>token_type_ids</code>, <code>attention_mask</code>, etc.</li> </ul> <div class="codehilite"><pre><span></span><code><span class="c1"># We define the maximum number of tokens after tokenization that each training sample</span> <span class="c1"># will have</span> <span class="n">max_num_tokens</span> <span class="o">=</span> <span class="n">BLOCK_SIZE</span> <span class="o">-</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">num_special_tokens_to_add</span><span class="p">(</span><span class="n">pair</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="k">def</span><span class="w"> </span><span class="nf">prepare_train_features</span><span class="p">(</span><span class="n">examples</span><span class="p">):</span> <span class="w"> </span><span class="sd">&quot;&quot;&quot;Function to prepare features for NSP task</span> <span class="sd"> Arguments:</span> <span class="sd"> examples: A dictionary with 1 key (&quot;text&quot;)</span> <span class="sd"> text: List of raw documents (str)</span> <span class="sd"> Returns:</span> <span class="sd"> examples: A dictionary with 4 keys</span> <span class="sd"> input_ids: List of tokenized, concatnated, and batched</span> <span class="sd"> sentences from the individual raw documents (int)</span> <span class="sd"> token_type_ids: List of integers (0 or 1) corresponding</span> <span class="sd"> to: 0 for senetence no. 1 and padding, 1 for sentence</span> <span class="sd"> no. 2</span> <span class="sd"> attention_mask: List of integers (0 or 1) corresponding</span> <span class="sd"> to: 1 for non-padded tokens, 0 for padded</span> <span class="sd"> next_sentence_label: List of integers (0 or 1) corresponding</span> <span class="sd"> to: 1 if the second sentence actually follows the first,</span> <span class="sd"> 0 if the senetence is sampled from somewhere else in the corpus</span> <span class="sd"> &quot;&quot;&quot;</span> <span class="c1"># Remove un-wanted samples from the training set</span> <span class="n">examples</span><span class="p">[</span><span class="s2">&quot;document&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span> <span class="n">d</span><span class="o">.</span><span class="n">strip</span><span class="p">()</span> <span class="k">for</span> <span class="n">d</span> <span class="ow">in</span> <span class="n">examples</span><span class="p">[</span><span class="s2">&quot;text&quot;</span><span class="p">]</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">d</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="ow">and</span> <span class="ow">not</span> <span class="n">d</span><span class="o">.</span><span class="n">startswith</span><span class="p">(</span><span class="s2">&quot; =&quot;</span><span class="p">)</span> <span class="p">]</span> <span class="c1"># Split the documents from the dataset into it&#39;s individual sentences</span> <span class="n">examples</span><span class="p">[</span><span class="s2">&quot;sentences&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span> <span class="n">nltk</span><span class="o">.</span><span class="n">tokenize</span><span class="o">.</span><span class="n">sent_tokenize</span><span class="p">(</span><span class="n">document</span><span class="p">)</span> <span class="k">for</span> <span class="n">document</span> <span class="ow">in</span> <span class="n">examples</span><span class="p">[</span><span class="s2">&quot;document&quot;</span><span class="p">]</span> <span class="p">]</span> <span class="c1"># Convert the tokens into ids using the trained tokenizer</span> <span class="n">examples</span><span class="p">[</span><span class="s2">&quot;tokenized_sentences&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span> <span class="p">[</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">convert_tokens_to_ids</span><span class="p">(</span><span class="n">tokenizer</span><span class="o">.</span><span class="n">tokenize</span><span class="p">(</span><span class="n">sent</span><span class="p">))</span> <span class="k">for</span> <span class="n">sent</span> <span class="ow">in</span> <span class="n">doc</span><span class="p">]</span> <span class="k">for</span> <span class="n">doc</span> <span class="ow">in</span> <span class="n">examples</span><span class="p">[</span><span class="s2">&quot;sentences&quot;</span><span class="p">]</span> <span class="p">]</span> <span class="c1"># Define the outputs</span> <span class="n">examples</span><span class="p">[</span><span class="s2">&quot;input_ids&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">examples</span><span class="p">[</span><span class="s2">&quot;token_type_ids&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">examples</span><span class="p">[</span><span class="s2">&quot;attention_mask&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">examples</span><span class="p">[</span><span class="s2">&quot;next_sentence_label&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">doc_index</span><span class="p">,</span> <span class="n">document</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">examples</span><span class="p">[</span><span class="s2">&quot;tokenized_sentences&quot;</span><span class="p">]):</span> <span class="n">current_chunk</span> <span class="o">=</span> <span class="p">[]</span> <span class="c1"># a buffer stored current working segments</span> <span class="n">current_length</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">i</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># We *usually* want to fill up the entire sequence since we are padding</span> <span class="c1"># to `block_size` anyways, so short sequences are generally wasted</span> <span class="c1"># computation. However, we *sometimes*</span> <span class="c1"># (i.e., short_seq_prob == 0.1 == 10% of the time) want to use shorter</span> <span class="c1"># sequences to minimize the mismatch between pretraining and fine-tuning.</span> <span class="c1"># The `target_seq_length` is just a rough target however, whereas</span> <span class="c1"># `block_size` is a hard limit.</span> <span class="n">target_seq_length</span> <span class="o">=</span> <span class="n">max_num_tokens</span> <span class="k">if</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()</span> <span class="o">&lt;</span> <span class="n">SHORT_SEQ_PROB</span><span class="p">:</span> <span class="n">target_seq_length</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">max_num_tokens</span><span class="p">)</span> <span class="k">while</span> <span class="n">i</span> <span class="o">&lt;</span> <span class="nb">len</span><span class="p">(</span><span class="n">document</span><span class="p">):</span> <span class="n">segment</span> <span class="o">=</span> <span class="n">document</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="n">current_chunk</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">segment</span><span class="p">)</span> <span class="n">current_length</span> <span class="o">+=</span> <span class="nb">len</span><span class="p">(</span><span class="n">segment</span><span class="p">)</span> <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">document</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">current_length</span> <span class="o">&gt;=</span> <span class="n">target_seq_length</span><span class="p">:</span> <span class="k">if</span> <span class="n">current_chunk</span><span class="p">:</span> <span class="c1"># `a_end` is how many segments from `current_chunk` go into the `A`</span> <span class="c1"># (first) sentence.</span> <span class="n">a_end</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">current_chunk</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="mi">2</span><span class="p">:</span> <span class="n">a_end</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">current_chunk</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="n">tokens_a</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">a_end</span><span class="p">):</span> <span class="n">tokens_a</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">current_chunk</span><span class="p">[</span><span class="n">j</span><span class="p">])</span> <span class="n">tokens_b</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">current_chunk</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span> <span class="ow">or</span> <span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">()</span> <span class="o">&lt;</span> <span class="n">NSP_PROB</span><span class="p">:</span> <span class="n">is_random_next</span> <span class="o">=</span> <span class="kc">True</span> <span class="n">target_b_length</span> <span class="o">=</span> <span class="n">target_seq_length</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">tokens_a</span><span class="p">)</span> <span class="c1"># This should rarely go for more than one iteration for large</span> <span class="c1"># corpora. However, just to be careful, we try to make sure that</span> <span class="c1"># the random document is not the same as the document</span> <span class="c1"># we&#39;re processing.</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">10</span><span class="p">):</span> <span class="n">random_document_index</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span> <span class="mi">0</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">examples</span><span class="p">[</span><span class="s2">&quot;tokenized_sentences&quot;</span><span class="p">])</span> <span class="o">-</span> <span class="mi">1</span> <span class="p">)</span> <span class="k">if</span> <span class="n">random_document_index</span> <span class="o">!=</span> <span class="n">doc_index</span><span class="p">:</span> <span class="k">break</span> <span class="n">random_document</span> <span class="o">=</span> <span class="n">examples</span><span class="p">[</span><span class="s2">&quot;tokenized_sentences&quot;</span><span class="p">][</span> <span class="n">random_document_index</span> <span class="p">]</span> <span class="n">random_start</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">random_document</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">random_start</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">random_document</span><span class="p">)):</span> <span class="n">tokens_b</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">random_document</span><span class="p">[</span><span class="n">j</span><span class="p">])</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">tokens_b</span><span class="p">)</span> <span class="o">&gt;=</span> <span class="n">target_b_length</span><span class="p">:</span> <span class="k">break</span> <span class="c1"># We didn&#39;t actually use these segments so we &quot;put them back&quot; so</span> <span class="c1"># they don&#39;t go to waste.</span> <span class="n">num_unused_segments</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">current_chunk</span><span class="p">)</span> <span class="o">-</span> <span class="n">a_end</span> <span class="n">i</span> <span class="o">-=</span> <span class="n">num_unused_segments</span> <span class="k">else</span><span class="p">:</span> <span class="n">is_random_next</span> <span class="o">=</span> <span class="kc">False</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">a_end</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">current_chunk</span><span class="p">)):</span> <span class="n">tokens_b</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">current_chunk</span><span class="p">[</span><span class="n">j</span><span class="p">])</span> <span class="n">input_ids</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">build_inputs_with_special_tokens</span><span class="p">(</span> <span class="n">tokens_a</span><span class="p">,</span> <span class="n">tokens_b</span> <span class="p">)</span> <span class="c1"># add token type ids, 0 for sentence a, 1 for sentence b</span> <span class="n">token_type_ids</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">create_token_type_ids_from_sequences</span><span class="p">(</span> <span class="n">tokens_a</span><span class="p">,</span> <span class="n">tokens_b</span> <span class="p">)</span> <span class="n">padded</span> <span class="o">=</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span> <span class="p">{</span><span class="s2">&quot;input_ids&quot;</span><span class="p">:</span> <span class="n">input_ids</span><span class="p">,</span> <span class="s2">&quot;token_type_ids&quot;</span><span class="p">:</span> <span class="n">token_type_ids</span><span class="p">},</span> <span class="n">padding</span><span class="o">=</span><span class="s2">&quot;max_length&quot;</span><span class="p">,</span> <span class="n">max_length</span><span class="o">=</span><span class="n">MAX_LENGTH</span><span class="p">,</span> <span class="p">)</span> <span class="n">examples</span><span class="p">[</span><span class="s2">&quot;input_ids&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">padded</span><span class="p">[</span><span class="s2">&quot;input_ids&quot;</span><span class="p">])</span> <span class="n">examples</span><span class="p">[</span><span class="s2">&quot;token_type_ids&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">padded</span><span class="p">[</span><span class="s2">&quot;token_type_ids&quot;</span><span class="p">])</span> <span class="n">examples</span><span class="p">[</span><span class="s2">&quot;attention_mask&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">padded</span><span class="p">[</span><span class="s2">&quot;attention_mask&quot;</span><span class="p">])</span> <span class="n">examples</span><span class="p">[</span><span class="s2">&quot;next_sentence_label&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="mi">1</span> <span class="k">if</span> <span class="n">is_random_next</span> <span class="k">else</span> <span class="mi">0</span><span class="p">)</span> <span class="n">current_chunk</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">current_length</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">i</span> <span class="o">+=</span> <span class="mi">1</span> <span class="c1"># We delete all the un-necessary columns from our dataset</span> <span class="k">del</span> <span class="n">examples</span><span class="p">[</span><span class="s2">&quot;document&quot;</span><span class="p">]</span> <span class="k">del</span> <span class="n">examples</span><span class="p">[</span><span class="s2">&quot;sentences&quot;</span><span class="p">]</span> <span class="k">del</span> <span class="n">examples</span><span class="p">[</span><span class="s2">&quot;text&quot;</span><span class="p">]</span> <span class="k">del</span> <span class="n">examples</span><span class="p">[</span><span class="s2">&quot;tokenized_sentences&quot;</span><span class="p">]</span> <span class="k">return</span> <span class="n">examples</span> <span class="n">tokenized_dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">map</span><span class="p">(</span> <span class="n">prepare_train_features</span><span class="p">,</span> <span class="n">batched</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">remove_columns</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;text&quot;</span><span class="p">],</span> <span class="n">num_proc</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Parameter &#39;function&#39;=&lt;function prepare_train_features at 0x7fd4a214cb90&gt; of the transform datasets.arrow_dataset.Dataset._map_single couldn&#39;t be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won&#39;t be showed. 0%| | 0/5 [00:00&lt;?, ?ba/s] 0%| | 0/1 [00:00&lt;?, ?ba/s] 0%| | 0/1 [00:00&lt;?, ?ba/s] </code></pre></div> </div> <p>For MLM we are going to use the same preprocessing as before for our dataset with one additional step: we randomly mask some tokens (by replacing them by [MASK]) and the labels will be adjusted to only include the masked tokens (we don't have to predict the non-masked tokens). If you use a tokenizer you trained yourself, make sure the [MASK] token is among the special tokens you passed during training!</p> <p>To get the data ready for MLM, we simply use the <code>collator</code> called the <code>DataCollatorForLanguageModeling</code> provided by the 🤗 Transformers library on our dataset that is already ready for the NSP task. The <code>collator</code> expects certain parameters. We use the default ones from the original BERT paper in this notebook. The <code>return_tensors='tf'</code> ensures that we get <a href="https://www.tensorflow.org/api_docs/python/tf/Tensor"><code>tf.Tensor</code></a> objects back.</p> <div class="codehilite"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">transformers</span><span class="w"> </span><span class="kn">import</span> <span class="n">DataCollatorForLanguageModeling</span> <span class="n">collater</span> <span class="o">=</span> <span class="n">DataCollatorForLanguageModeling</span><span class="p">(</span> <span class="n">tokenizer</span><span class="o">=</span><span class="n">tokenizer</span><span class="p">,</span> <span class="n">mlm</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">mlm_probability</span><span class="o">=</span><span class="n">MLM_PROB</span><span class="p">,</span> <span class="n">return_tensors</span><span class="o">=</span><span class="s2">&quot;tf&quot;</span> <span class="p">)</span> </code></pre></div> <p>Next we define our training set with which we train our model. Again, 🤗 Datasets provides us with the <code>to_tf_dataset</code> method which will help us integrate our dataset with the <code>collator</code> defined above. The method expects certain parameters:</p> <ul> <li><strong>columns</strong>: the columns which will serve as our independent variables</li> <li><strong>label_cols</strong>: the columns which will serve as our labels or dependant variables</li> <li><strong>batch_size</strong>: our batch size for training</li> <li><strong>shuffle</strong>: whether we want to shuffle our training dataset</li> <li><strong>collate_fn</strong>: our collator function</li> </ul> <div class="codehilite"><pre><span></span><code><span class="n">train</span> <span class="o">=</span> <span class="n">tokenized_dataset</span><span class="p">[</span><span class="s2">&quot;train&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">to_tf_dataset</span><span class="p">(</span> <span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;input_ids&quot;</span><span class="p">,</span> <span class="s2">&quot;token_type_ids&quot;</span><span class="p">,</span> <span class="s2">&quot;attention_mask&quot;</span><span class="p">],</span> <span class="n">label_cols</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;labels&quot;</span><span class="p">,</span> <span class="s2">&quot;next_sentence_label&quot;</span><span class="p">],</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">TRAIN_BATCH_SIZE</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">collate_fn</span><span class="o">=</span><span class="n">collater</span><span class="p">,</span> <span class="p">)</span> <span class="n">validation</span> <span class="o">=</span> <span class="n">tokenized_dataset</span><span class="p">[</span><span class="s2">&quot;validation&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">to_tf_dataset</span><span class="p">(</span> <span class="n">columns</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;input_ids&quot;</span><span class="p">,</span> <span class="s2">&quot;token_type_ids&quot;</span><span class="p">,</span> <span class="s2">&quot;attention_mask&quot;</span><span class="p">],</span> <span class="n">label_cols</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;labels&quot;</span><span class="p">,</span> <span class="s2">&quot;next_sentence_label&quot;</span><span class="p">],</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">TRAIN_BATCH_SIZE</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">collate_fn</span><span class="o">=</span><span class="n">collater</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="defining-the-model">Defining the model</h2> <p>To define our model, first we need to define a config which will help us define certain parameters of our model architecture. This includes parameters like number of transformer layers, number of attention heads, hidden dimension, etc. For this notebook, we try to define the exact config defined in the original BERT paper.</p> <p>We can easily achieve this using the <code>BertConfig</code> class from the 🤗 Transformers library. The <code>from_pretrained()</code> method expects the name of a model. Here we define the simplest model with which we also trained our model, i.e., <code>bert-base-cased</code>.</p> <div class="codehilite"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">transformers</span><span class="w"> </span><span class="kn">import</span> <span class="n">BertConfig</span> <span class="n">config</span> <span class="o">=</span> <span class="n">BertConfig</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="n">MODEL_CHECKPOINT</span><span class="p">)</span> </code></pre></div> <p>For defining our model we use the <code>TFBertForPreTraining</code> class from the 🤗 Transformers library. This class internally handles everything starting from defining our model, to unpacking our inputs and calculating the loss. So we need not do anything ourselves except defining the model with the correct <code>config</code> we want!</p> <div class="codehilite"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">transformers</span><span class="w"> </span><span class="kn">import</span> <span class="n">TFBertForPreTraining</span> <span class="n">model</span> <span class="o">=</span> <span class="n">TFBertForPreTraining</span><span class="p">(</span><span class="n">config</span><span class="p">)</span> </code></pre></div> <p>Now we define our optimizer and compile the model. The loss calculation is handled internally and so we need not worry about that!</p> <div class="codehilite"><pre><span></span><code><span class="n">optimizer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="n">LEARNING_RATE</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>No loss specified in compile() - the model&#39;s internal loss computation will be used as the loss. Don&#39;t panic - this is a common way to train TensorFlow models in Transformers! To disable this behaviour please pass a loss argument, or explicitly pass `loss=None` if you do not want your model to compute a loss. </code></pre></div> </div> <p>Finally all steps are done and now we can start training our model!</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">validation</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">MAX_EPOCHS</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>483/483 [==============================] - 96s 141ms/step - loss: 8.3765 - val_loss: 8.5572 &lt;keras.callbacks.History at 0x7fd27c219790&gt; </code></pre></div> </div> <p>Our model has now been trained! We suggest to please train the model on the complete dataset for atleast 50 epochs for decent performance. The pretrained model now acts as a language model and is meant to be fine-tuned on a downstream task. Thus it can now be fine-tuned on any downstream task like Question Answering, Text Classification etc.!</p> <p>Now you can push this model to 🤗 Model Hub and also share it with with all your friends, family, favorite pets: they can all load it with the identifier <code>"your-username/the-name-you-picked"</code> so for instance:</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span><span class="o">.</span><span class="n">push_to_hub</span><span class="p">(</span><span class="s2">&quot;pretrained-bert&quot;</span><span class="p">,</span> <span class="n">organization</span><span class="o">=</span><span class="s2">&quot;keras-io&quot;</span><span class="p">)</span> <span class="n">tokenizer</span><span class="o">.</span><span class="n">push_to_hub</span><span class="p">(</span><span class="s2">&quot;pretrained-bert&quot;</span><span class="p">,</span> <span class="n">organization</span><span class="o">=</span><span class="s2">&quot;keras-io&quot;</span><span class="p">)</span> </code></pre></div> <p>And after you push your model this is how you can load it in the future!</p> <div class="codehilite"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">transformers</span><span class="w"> </span><span class="kn">import</span> <span class="n">TFBertForPreTraining</span> <span class="n">model</span> <span class="o">=</span> <span class="n">TFBertForPreTraining</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s2">&quot;your-username/my-awesome-model&quot;</span><span class="p">)</span> </code></pre></div> <p>or, since it's a pretrained model and you would generally use it for fine-tuning on a downstream task, you can also load it for some other task like:</p> <div class="codehilite"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">transformers</span><span class="w"> </span><span class="kn">import</span> <span class="n">TFBertForSequenceClassification</span> <span class="n">model</span> <span class="o">=</span> <span class="n">TFBertForSequenceClassification</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s2">&quot;your-username/my-awesome-model&quot;</span><span class="p">)</span> </code></pre></div> <p>In this case, the pretraining head will be dropped and the model will just be initialized with the transformer layers. A new task-specific head will be added with random weights.</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#pretraining-bert-with-hugging-face-transformers'>Pretraining BERT with Hugging Face Transformers</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-3'> <a href='#bert-bidirectional-encoder-representations-from-transformers'>BERT (Bidirectional Encoder Representations from Transformers)</a> </div> <div class='k-outline-depth-3'> <a href='#masked-language-modeling-mlm'>Masked Language Modeling (MLM)</a> </div> <div class='k-outline-depth-3'> <a href='#next-sentence-prediction-nsp'>Next Sentence Prediction (NSP)</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-3'> <a href='#installing-the-requirements'>Installing the requirements</a> </div> <div class='k-outline-depth-3'> <a href='#importing-the-necessary-libraries'>Importing the necessary libraries</a> </div> <div class='k-outline-depth-3'> <a href='#define-certain-variables'>Define certain variables</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#load-the-wikitext-dataset'>Load the WikiText dataset</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#training-a-new-tokenizer'>Training a new Tokenizer</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#data-preprocessing'>Data Pre-processing</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#defining-the-model'>Defining the model</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>

Pages: 1 2 3 4 5 6 7 8 9 10