Parameter-efficient fine-tuning of GPT-2 with LoRA
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Parameter-efficient fine-tuning of GPT-2 with LoRA"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Parameter-efficient fine-tuning of GPT-2 with LoRA"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Parameter-efficient fine-tuning of GPT-2 with LoRA</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); 
ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink active" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_from_scratch/">Text classification from scratch</a> <a class="nav-sublink2" href="/examples/nlp/active_learning_review_classification/">Review Classification using Active Learning</a> <a class="nav-sublink2" href="/examples/nlp/fnet_classification_with_keras_hub/">Text Classification using FNet</a> <a class="nav-sublink2" href="/examples/nlp/multi_label_classification/">Large-scale multi-label text classification</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_with_transformer/">Text classification with Transformer</a> <a class="nav-sublink2" 
href="/examples/nlp/text_classification_with_switch_transformer/">Text classification with Switch Transformer</a> <a class="nav-sublink2" href="/examples/nlp/tweet-classification-using-tfdf/">Text classification using Decision Forests and pretrained embeddings</a> <a class="nav-sublink2" href="/examples/nlp/pretrained_word_embeddings/">Using pre-trained word embeddings</a> <a class="nav-sublink2" href="/examples/nlp/bidirectional_lstm_imdb/">Bidirectional LSTM on IMDB</a> <a class="nav-sublink2" href="/examples/nlp/data_parallel_training_with_keras_hub/">Data Parallel Training with KerasHub and tf.distribute</a> <a class="nav-sublink2" href="/examples/nlp/neural_machine_translation_with_keras_hub/">English-to-Spanish translation with KerasHub</a> <a class="nav-sublink2" href="/examples/nlp/neural_machine_translation_with_transformer/">English-to-Spanish translation with a sequence-to-sequence Transformer</a> <a class="nav-sublink2" href="/examples/nlp/lstm_seq2seq/">Character-level recurrent sequence-to-sequence model</a> <a class="nav-sublink2" href="/examples/nlp/multimodal_entailment/">Multimodal entailment</a> <a class="nav-sublink2" href="/examples/nlp/ner_transformers/">Named Entity Recognition using Transformers</a> <a class="nav-sublink2" href="/examples/nlp/text_extraction_with_bert/">Text Extraction with BERT</a> <a class="nav-sublink2" href="/examples/nlp/addition_rnn/">Sequence to sequence learning for performing number addition</a> <a class="nav-sublink2" href="/examples/nlp/semantic_similarity_with_keras_hub/">Semantic Similarity with KerasHub</a> <a class="nav-sublink2" href="/examples/nlp/semantic_similarity_with_bert/">Semantic Similarity with BERT</a> <a class="nav-sublink2" href="/examples/nlp/sentence_embeddings_with_sbert/">Sentence embeddings using Siamese RoBERTa-networks</a> <a class="nav-sublink2" href="/examples/nlp/masked_language_modeling/">End-to-end Masked Language Modeling with BERT</a> <a class="nav-sublink2" 
href="/examples/nlp/abstractive_summarization_with_bart/">Abstractive Text Summarization with BART</a> <a class="nav-sublink2" href="/examples/nlp/pretraining_BERT/">Pretraining BERT with Hugging Face Transformers</a> <a class="nav-sublink2 active" href="/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/">Parameter-efficient fine-tuning of GPT-2 with LoRA</a> <a class="nav-sublink2" href="/examples/nlp/mlm_training_tpus/">Training a language model from scratch with 🤗 Transformers and TPUs</a> <a class="nav-sublink2" href="/examples/nlp/multiple_choice_task_with_transfer_learning/">MultipleChoice Task with Transfer Learning</a> <a class="nav-sublink2" href="/examples/nlp/question_answering/">Question Answering with Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/nlp/t5_hf_summarization/">Abstractive Summarization with Hugging Face Transformers</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { 
e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." 
autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/nlp/'>Natural Language Processing</a> / Parameter-efficient fine-tuning of GPT-2 with LoRA </div> <div class='k-content'> <h1 id="parameterefficient-finetuning-of-gpt2-with-lora">Parameter-efficient fine-tuning of GPT-2 with LoRA</h1> <p><strong>Author:</strong> <a href="https://github.com/abheesht17/">Abheesht Sharma</a>, <a href="https://github.com/mattdangerw/">Matthew Watson</a><br> <strong>Date created:</strong> 2023/05/27<br> <strong>Last modified:</strong> 2023/05/27<br> <strong>Description:</strong> Use KerasHub to fine-tune a GPT-2 LLM with LoRA.</p> <div class='example_version_banner 
keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/nlp/ipynb/parameter_efficient_finetuning_of_gpt2_with_lora.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>Large Language Models (LLMs) have been shown to be effective at a variety of NLP tasks. An LLM is first pre-trained on a large corpus of text in a self-supervised fashion. Pre-training helps LLMs learn general-purpose knowledge, such as statistical relationships between words. An LLM can then be fine-tuned on a downstream task of interest (such as sentiment analysis).</p> <p>However, LLMs are extremely large, and we don't need to train all of their parameters during fine-tuning, especially because the datasets used for fine-tuning are relatively small. In other words, LLMs are over-parametrized for fine-tuning. This is where <a href="https://arxiv.org/abs/2106.09685">Low-Rank Adaptation (LoRA)</a> comes in: it significantly reduces the number of trainable parameters, which decreases training time and GPU memory usage while maintaining the quality of the outputs.</p> <p>In this example, we will explain LoRA in technical terms, show how the technical explanation translates to code, hack KerasHub's <a href="https://keras.io/api/keras_hub/models/gpt2/">GPT-2 model</a> and fine-tune it on the next token prediction task using LoRA. 
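</p> <p>To make the parameter savings mentioned above concrete, here is a minimal sketch (not part of the pipeline below) that counts the trainable weights of a single dense projection when it is trained in full versus through a low-rank update. It assumes a square 768 x 768 kernel, the hidden size of GPT-2 base, and a rank of 4, the value of the <code>RANK</code> hyperparameter defined in the Setup section. The helper names are hypothetical.</p>
<div class="codehilite"><pre><span></span><code>
```python
def full_finetune_params(in_dim, out_dim):
    """Trainable weights when the whole kernel W (in_dim x out_dim) is updated."""
    return in_dim * out_dim


def lora_params(in_dim, out_dim, rank):
    """Trainable weights when only the low-rank factors are updated.

    LoRA freezes W and learns A (in_dim x rank) and B (rank x out_dim).
    """
    return in_dim * rank + rank * out_dim


hidden_dim = 768  # Assumption: hidden size of the GPT-2 base model.
rank = 4  # Assumption: same value as RANK in this example.

full = full_finetune_params(hidden_dim, hidden_dim)
lora = lora_params(hidden_dim, hidden_dim, rank)

print(f"Full fine-tuning: {full} trainable weights")
print(f"LoRA (rank {rank}): {lora} trainable weights")
print(f"Reduction factor: {full / lora:.0f}x")
```
</code></pre></div>
<p>Under these assumptions, a single projection goes from 589,824 trainable weights to 6,144, a 96-fold reduction.</p> <p>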
We will compare LoRA GPT-2 with a fully fine-tuned GPT-2 in terms of the quality of the generated text, training time and GPU memory usage.</p> <p>Note: This example runs on the TensorFlow backend purely for the <a href="https://www.tensorflow.org/api_docs/python/tf/config/experimental/get_memory_info"><code>tf.config.experimental.get_memory_info</code></a> API, which makes it easy to plot memory usage. Outside of the memory usage callback, this example will run on the <code>jax</code> and <code>torch</code> backends.</p> <hr /> <h2 id="setup">Setup</h2> <p>Before we start implementing the pipeline, let's install and import all the libraries we need. We'll be using the KerasHub library.</p> <p>Next, let's enable mixed precision training. This will help us reduce the training time.</p>
<div class="codehilite"><pre><span></span><code><span class="err">!</span><span class="n">pip</span> <span class="n">install</span> <span class="o">-</span><span class="n">q</span> <span class="o">--</span><span class="n">upgrade</span> <span class="n">keras</span><span class="o">-</span><span class="n">hub</span>
<span class="err">!</span><span class="n">pip</span> <span class="n">install</span> <span class="o">-</span><span class="n">q</span> <span class="o">--</span><span class="n">upgrade</span> <span class="n">keras</span>  <span class="c1"># Upgrade to Keras 3.</span>
</code></pre></div>
<div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span>

<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"tensorflow"</span>

<span class="kn">import</span> <span class="nn">keras_hub</span>
<span class="kn">import</span> <span class="nn">keras</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span>
<span class="kn">import</span> <span class="nn">tensorflow_datasets</span> <span class="k">as</span> <span class="nn">tfds</span>
<span class="kn">import</span> <span class="nn">time</span>

<span class="n">keras</span><span class="o">.</span><span class="n">mixed_precision</span><span class="o">.</span><span class="n">set_global_policy</span><span class="p">(</span><span class="s2">"mixed_float16"</span><span class="p">)</span>
</code></pre></div>
<p>Let's also define our hyperparameters.</p>
<div class="codehilite"><pre><span></span><code><span class="c1"># General hyperparameters</span>
<span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">32</span>
<span class="n">NUM_BATCHES</span> <span class="o">=</span> <span class="mi">500</span>
<span class="n">EPOCHS</span> <span class="o">=</span> <span class="mi">1</span>  <span class="c1"># Can be set to a higher value for better results</span>
<span class="n">MAX_SEQUENCE_LENGTH</span> <span class="o">=</span> <span class="mi">128</span>
<span class="n">MAX_GENERATION_LENGTH</span> <span class="o">=</span> <span class="mi">200</span>

<span class="n">GPT2_PRESET</span> <span class="o">=</span> <span class="s2">"gpt2_base_en"</span>

<span class="c1"># LoRA-specific hyperparameters</span>
<span class="n">RANK</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">ALPHA</span> <span class="o">=</span> <span class="mf">32.0</span>
</code></pre></div>
<hr /> <h2 id="dataset">Dataset</h2> <p>Let's load a Reddit dataset. We will fine-tune both the GPT-2 model and the LoRA GPT-2 model on a subset of this dataset. 
The aim is to produce text similar in style to Reddit posts.</p>
<div class="codehilite"><pre><span></span><code><span class="n">reddit_ds</span> <span class="o">=</span> <span class="n">tfds</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"reddit_tifu"</span><span class="p">,</span> <span class="n">split</span><span class="o">=</span><span class="s2">"train"</span><span class="p">,</span> <span class="n">as_supervised</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</code></pre></div>
<p>The dataset has two fields: <code>document</code> and <code>title</code>.</p>
<div class="codehilite"><pre><span></span><code><span class="k">for</span> <span class="n">document</span><span class="p">,</span> <span class="n">title</span> <span class="ow">in</span> <span class="n">reddit_ds</span><span class="p">:</span>
    <span class="nb">print</span><span class="p">(</span><span class="n">document</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
    <span class="nb">print</span><span class="p">(</span><span class="n">title</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
    <span class="k">break</span>
</code></pre></div>
<div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>b"me and a friend decided to go to the beach last sunday. we loaded up and headed out. we were about half way there when i decided that i was not leaving till i had seafood. \n\nnow i'm not talking about red lobster. no friends i'm talking about a low country boil. i found the restaurant and got directions. i don't know if any of you have heard about the crab shack on tybee island but let me tell you it's worth it. \n\nwe arrived and was seated quickly. we decided to get a seafood sampler for two and split it. the waitress bought it out on separate platters for us. the amount of food was staggering. 
two types of crab, shrimp, mussels, crawfish, andouille sausage, red potatoes, and corn on the cob. i managed to finish it and some of my friends crawfish and mussels. it was a day to be a fat ass. we finished paid for our food and headed to the beach. \n\nfunny thing about seafood. it runs through me faster than a kenyan \n\nwe arrived and walked around a bit. it was about 45min since we arrived at the beach when i felt a rumble from the depths of my stomach. i ignored it i didn't want my stomach to ruin our fun. i pushed down the feeling and continued. about 15min later the feeling was back and stronger than before. again i ignored it and continued. 5min later it felt like a nuclear reactor had just exploded in my stomach. i started running. i yelled to my friend to hurry the fuck up. \n\nrunning in sand is extremely hard if you did not know this. we got in his car and i yelled at him to floor it. my stomach was screaming and if he didn't hurry i was gonna have this baby in his car and it wasn't gonna be pretty. after a few red lights and me screaming like a woman in labor we made it to the store. \n\ni practically tore his car door open and ran inside. i ran to the bathroom opened the door and barely got my pants down before the dam burst and a flood of shit poured from my ass. \n\ni finished up when i felt something wet on my ass. i rubbed it thinking it was back splash. no, mass was covered in the after math of me abusing the toilet. i grabbed all the paper towels i could and gave my self a whores bath right there. \n\ni sprayed the bathroom down with the air freshener and left. an elderly lady walked in quickly and closed the door. i was just about to walk away when i heard gag. instead of walking i ran. i got to the car and told him to get the hell out of there." b'liking seafood' </code></pre></div> </div> <p>We'll now batch the dataset and retain only the <code>document</code> field because we are fine-tuning the model on the next word prediction task. 
We take only a subset of the dataset for the purpose of this example.</p>
<div class="codehilite"><pre><span></span><code><span class="n">train_ds</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">reddit_ds</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">document</span><span class="p">,</span> <span class="n">_</span><span class="p">:</span> <span class="n">document</span><span class="p">)</span>
    <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span>
    <span class="o">.</span><span class="n">cache</span><span class="p">()</span>
    <span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span><span class="p">)</span>
<span class="p">)</span>
<span class="n">train_ds</span> <span class="o">=</span> <span class="n">train_ds</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="n">NUM_BATCHES</span><span class="p">)</span>
</code></pre></div>
<hr /> <h2 id="helper-functions">Helper functions</h2> <p>Before we begin fine-tuning the models, let's define a few helper functions and classes.</p> <h3 id="callback-for-tracking-gpu-memory-usage">Callback for tracking GPU memory usage</h3> <p>We'll define a custom callback that tracks GPU memory usage. 
The callback uses TensorFlow's <a href="https://www.tensorflow.org/api_docs/python/tf/config/experimental/get_memory_info"><code>tf.config.experimental.get_memory_info</code></a> API.</p> <p>Here, we assume that we are using a single GPU, <code>GPU:0</code>.</p>
<div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">GPUMemoryCallback</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">Callback</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span>
        <span class="bp">self</span><span class="p">,</span>
        <span class="n">target_batches</span><span class="p">,</span>
        <span class="n">print_stats</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
        <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span>
    <span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">target_batches</span> <span class="o">=</span> <span class="n">target_batches</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">print_stats</span> <span class="o">=</span> <span class="n">print_stats</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">memory_usage</span> <span class="o">=</span> <span class="p">[]</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">labels</span> <span class="o">=</span> <span class="p">[]</span>

    <span class="k">def</span> <span class="nf">_compute_memory_usage</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="n">memory_stats</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">experimental</span><span class="o">.</span><span class="n">get_memory_info</span><span class="p">(</span><span class="s2">"GPU:0"</span><span class="p">)</span>
        <span class="c1"># Convert bytes to GB and store in list.</span>
        <span class="n">peak_usage</span> <span class="o">=</span> <span class="nb">round</span><span class="p">(</span><span class="n">memory_stats</span><span class="p">[</span><span class="s2">"peak"</span><span class="p">]</span> <span class="o">/</span> <span class="p">(</span><span class="mi">2</span><span class="o">**</span><span class="mi">30</span><span class="p">),</span> <span class="mi">3</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">memory_usage</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">peak_usage</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">on_epoch_begin</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">_compute_memory_usage</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">labels</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="sa">f</span><span class="s2">"epoch </span><span class="si">{</span><span class="n">epoch</span><span class="si">}</span><span class="s2"> start"</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">on_train_batch_begin</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
        <span class="k">if</span> <span class="n">batch</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">target_batches</span><span class="p">:</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">_compute_memory_usage</span><span class="p">()</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">labels</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="sa">f</span><span class="s2">"batch </span><span class="si">{</span><span class="n">batch</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">on_epoch_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">_compute_memory_usage</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">labels</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="sa">f</span><span class="s2">"epoch </span><span class="si">{</span><span class="n">epoch</span><span class="si">}</span><span class="s2"> end"</span><span class="p">)</span>
</code></pre></div>
<h3 id="function-for-text-generation">Function for text generation</h3> <p>Here is a helper function to generate text.</p>
<div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">generate_text</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">input_text</span><span class="p">,</span> <span class="n">max_length</span><span class="o">=</span><span class="mi">200</span><span class="p">):</span>
    <span class="n">start</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>

    <span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">generate</span><span class="p">(</span><span class="n">input_text</span><span class="p">,</span> <span class="n">max_length</span><span class="o">=</span><span class="n">max_length</span><span class="p">)</span>
    <span class="nb">print</span><span class="p">(</span><span class="s2">"</span><span class="se">\n</span><span class="s2">Output:"</span><span class="p">)</span>
    <span class="nb">print</span><span class="p">(</span><span class="n">output</span><span class="p">)</span>

    <span class="n">end</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
    <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Total Time Elapsed: </span><span class="si">{</span><span class="n">end</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="n">start</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2">s"</span><span class="p">)</span>
</code></pre></div>
<h3 id="define-optimizer-and-loss">Define optimizer and loss</h3> <p>We will use the AdamW optimizer and cross-entropy loss for training both models.</p>
<div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">get_optimizer_and_loss</span><span class="p">():</span>
    <span class="n">optimizer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">AdamW</span><span class="p">(</span>
        <span class="n">learning_rate</span><span class="o">=</span><span class="mf">5e-5</span><span class="p">,</span>
        <span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span>
        <span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">,</span>
        <span class="n">global_clipnorm</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span>  <span class="c1"># Gradient clipping.</span>
    <span class="p">)</span>
    <span class="c1"># Exclude layernorm and bias terms from weight decay. All names are</span>
    <span class="c1"># passed in a single call so that no call overrides an earlier one.</span>
    <span class="n">optimizer</span><span class="o">.</span><span class="n">exclude_from_weight_decay</span><span class="p">(</span><span class="n">var_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"bias"</span><span class="p">,</span> <span class="s2">"gamma"</span><span class="p">,</span> <span class="s2">"beta"</span><span class="p">])</span>

    <span class="n">loss</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">loss</span>
</code></pre></div>
<hr /> <h2 id="finetune-gpt2">Fine-tune GPT-2</h2> <p>Let's load the model and preprocessor first. We use a sequence length of 128 instead of 1024 (which is the default sequence length). 
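</p> <p>To illustrate what a fixed sequence length means for the inputs, here is a simplified, framework-free sketch that truncates or pads a list of token ids and then forms next-token features and labels. It only approximates what a causal LM preprocessor does; the helper name, the pad id and the toy token ids are hypothetical.</p>
<div class="codehilite"><pre><span></span><code>
```python
def pack_for_causal_lm(token_ids, sequence_length, pad_id=0):
    """Truncate or pad `token_ids`, then build next-token (features, labels)."""
    # Keep one extra token so that features and labels both have
    # `sequence_length` entries after shifting by one position.
    packed = token_ids[: sequence_length + 1]
    packed = packed + [pad_id] * (sequence_length + 1 - len(packed))
    features = packed[:-1]  # Tokens the model sees.
    labels = packed[1:]  # The token that follows each feature.
    return features, labels


# Toy example with hypothetical token ids and a sequence length of 4.
long_features, long_labels = pack_for_causal_lm([11, 12, 13, 14, 15, 16, 17], 4)
short_features, short_labels = pack_for_causal_lm([21, 22, 23], 4)

print(long_features, long_labels)  # Tokens beyond the limit are dropped.
print(short_features, short_labels)  # Short inputs are padded with pad_id.
```
</code></pre></div>
<p>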
This will limit our ability to predict long sequences, but will allow us to run this example quickly on Colab.</p> <div class="codehilite"><pre><span></span><code><span class="n">preprocessor</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">GPT2CausalLMPreprocessor</span><span class="o">.</span><span class="n">from_preset</span><span class="p">(</span> <span class="s2">"gpt2_base_en"</span><span class="p">,</span> <span class="n">sequence_length</span><span class="o">=</span><span class="n">MAX_SEQUENCE_LENGTH</span><span class="p">,</span> <span class="p">)</span> <span class="n">gpt2_lm</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">GPT2CausalLM</span><span class="o">.</span><span class="n">from_preset</span><span class="p">(</span> <span class="s2">"gpt2_base_en"</span><span class="p">,</span> <span class="n">preprocessor</span><span class="o">=</span><span class="n">preprocessor</span> <span class="p">)</span> <span class="n">gpt2_lm</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> </code></pre></div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Preprocessor: "gpt2_causal_lm_preprocessor"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Tokenizer (type) </span>┃<span style="font-weight: bold"> Vocab # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ gpt2_tokenizer (<span style="color: 
#0087ff; text-decoration-color: #0087ff">GPT2Tokenizer</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">50,257</span> │ └────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "gpt2_causal_lm"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃<span style="font-weight: bold"> Connected to </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ padding_mask (<span style="color: #0087ff; text-decoration-color: #0087ff">InputLayer</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ - │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ token_ids (<span style="color: #0087ff; text-decoration-color: #0087ff">InputLayer</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ - │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ gpt2_backbone (<span style="color: #0087ff; text-decoration-color: #0087ff">GPT2Backbone</span>) │ (<span 
style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">768</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">124,439,808</span> │ padding_mask[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>], │ │ │ │ │ token_ids[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>] │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ token_embedding │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">50257</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">38,597,376</span> │ gpt2_backbone[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>] │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">ReversibleEmbedding</span>) │ │ │ │ └───────────────────────────────┴───────────────────────────┴─────────────┴────────────────────────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">124,439,808</span> (474.70 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">124,439,808</span> (474.70 MB) 
</pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <p>Initialize the GPU memory tracker callback object and compile the model. We use the AdamW optimizer and the loss returned by <code>get_optimizer_and_loss()</code>, which keeps the learning rate fixed at <code>5e-5</code>.</p> <div class="codehilite"><pre><span></span><code><span class="n">gpu_memory_callback</span> <span class="o">=</span> <span class="n">GPUMemoryCallback</span><span class="p">(</span> <span class="n">target_batches</span><span class="o">=</span><span class="p">[</span><span class="mi">5</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">25</span><span class="p">,</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="mi">150</span><span class="p">,</span> <span class="mi">200</span><span class="p">,</span> <span class="mi">300</span><span class="p">,</span> <span class="mi">400</span><span class="p">,</span> <span class="mi">500</span><span class="p">],</span> <span class="n">print_stats</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">get_optimizer_and_loss</span><span class="p">()</span> <span class="n">gpt2_lm</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="n">loss</span><span class="p">,</span> <span class="n">weighted_metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"accuracy"</span><span 
class="p">],</span> <span class="p">)</span> </code></pre></div> <p>We are all set to train the model!</p> <div class="codehilite"><pre><span></span><code><span class="n">gpt2_lm</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_ds</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">EPOCHS</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">gpu_memory_callback</span><span class="p">])</span> <span class="n">gpt2_lm_memory_usage</span> <span class="o">=</span> <span class="n">gpu_memory_callback</span><span class="o">.</span><span class="n">memory_usage</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1701128462.076856 38706 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. W0000 00:00:1701128462.146837 38706 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update 500/500 ━━━━━━━━━━━━━━━━━━━━ 114s 128ms/step - accuracy: 0.3183 - loss: 3.3682 </code></pre></div> </div> <p>As a final step, let's generate some text. We will harness the power of XLA. The first call to <code>generate()</code> will be slow because of XLA compilation, but subsequent calls will be super-fast. 
:)</p> <div class="codehilite"><pre><span></span><code><span class="n">generate_text</span><span class="p">(</span><span class="n">gpt2_lm</span><span class="p">,</span> <span class="s2">"I like basketball"</span><span class="p">,</span> <span class="n">max_length</span><span class="o">=</span><span class="n">MAX_GENERATION_LENGTH</span><span class="p">)</span> <span class="n">generate_text</span><span class="p">(</span><span class="n">gpt2_lm</span><span class="p">,</span> <span class="s2">"That Italian restaurant is"</span><span class="p">,</span> <span class="n">max_length</span><span class="o">=</span><span class="n">MAX_GENERATION_LENGTH</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Output: I like basketball, but this one actually happened a few months ago. </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>i was on my way to a party in the city when i noticed a group of guys were playing basketball. one of my friends, a guy named "jenny," was playing. jenny's mom, a very nice girl, was sitting on her couch. </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>jenny and jenny were sitting in a circle around her, and i started to play some of my favorite basketball games. i got to the end of the circle and jenny started to run. i didn't know how jenny was doing. she ran, but it Total Time Elapsed: 6.66s </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Output: That Italian restaurant is a bit of a mystery, because the place is closed. so i was at my friends house and i went to grab some food, so i got the usual pizza and some chicken, but it wasn't really the pizza, so i just grabbed my friend's pizza. 
i had a lot of chicken, but i was hungry, so i decided to grab a few of the other pizza's that were already in there. </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>i was eating the pizza with some friends and i was eating the pizza and then i got a knock on the door. </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>the guy in front of me is Total Time Elapsed: 0.22s </code></pre></div> </div> <hr /> <h2 id="lora-gpt2">LoRA GPT-2</h2> <p>In this section, we discuss the technical details of LoRA, build a LoRA GPT-2 model, fine-tune it and generate text.</p> <h3 id="what-exactly-is-lora">What exactly is LoRA?</h3> <p>LoRA is a parameter-efficient fine-tuning technique for LLMs. It freezes the weights of the LLM, and injects trainable rank-decomposition matrices. Let's understand this more clearly.</p> <p>Assume we have an <code>n x n</code> pre-trained dense layer (or weight matrix), <code>W0</code>. We initialize two dense layers, <code>A</code> and <code>B</code>, of shapes <code>n x rank</code>, and <code>rank x n</code>, respectively. <code>rank</code> is much smaller than <code>n</code>. In the paper, values between 1 and 4 are shown to work well.</p> <h4 id="lora-equation">LoRA equation</h4> <p>The original equation is <code>output = W0x + b0</code>, where <code>x</code> is the input, <code>W0</code> and <code>b0</code> are the weight matrix and bias terms of the original dense layer (frozen). The LoRA equation is: <code>output = W0x + b0 + BAx</code>, where <code>A</code> and <code>B</code> are the rank-decomposition matrices.</p> <p>LoRA is based on the idea that updates to the weights of the pre-trained language model have a low "intrinsic rank" since pre-trained language models are over-parametrized. 
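</p> <p>As a rough NumPy sketch of the LoRA equation above (not the Keras layer built later in this guide), the snippet below uses the row-vector convention of a Keras dense layer, so the low-rank branch is written <code>x @ A @ B</code> with <code>A</code> of shape <code>n x rank</code> and <code>B</code> of shape <code>rank x n</code>. The sizes, the random stand-in weights and the helper name <code>lora_forward</code> are assumptions made purely for illustration.</p>
<div class="codehilite"><pre><span></span><code>import numpy as np

n, rank = 768, 4  # illustrative sizes, not read from a real model
rng = np.random.default_rng(0)

# Frozen pre-trained weights (random stand-ins here).
W0 = rng.normal(size=(n, n))
b0 = rng.normal(size=(n,))

# LoRA matrices: A is randomly initialized, B starts at zero.
A = rng.normal(size=(n, rank))
B = np.zeros((rank, n))


def lora_forward(x, A, B):
    # Original dense output plus the low-rank update.
    return x @ W0 + b0 + x @ A @ B


x = rng.normal(size=(2, n))  # a batch of two input vectors
base_output = x @ W0 + b0

# With B == 0 the LoRA branch contributes nothing, so training starts
# from exactly the pre-trained behaviour.
starts_unchanged = np.allclose(lora_forward(x, A, B), base_output)

# Once B moves away from zero, the weight update A @ B is an n x n
# matrix whose rank is at most `rank`.
B_trained = rng.normal(size=(rank, n))
update_rank = np.linalg.matrix_rank(A @ B_trained)
print(starts_unchanged, update_rank)
</code></pre></div>
<p>Initializing <code>B</code> to zeros mirrors the <code>kernel_initializer="zeros"</code> choice in the LoRA layer defined later on this page. 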
The predictive performance of full fine-tuning can be matched even when <code>W0</code>'s updates are constrained to low-rank decomposition matrices.</p> <p align="center"> <img src="https://i.imgur.com/f4TFqMi.png" alt="lora_diagram" height="250"/> </p> <p><br></p> <h4 id="number-of-trainable-parameters">Number of trainable parameters</h4> <p>Let's do some quick math. Suppose <code>n</code> is 768, and <code>rank</code> is 4. <code>W0</code> has <code>768 x 768 = 589,824</code> parameters, whereas the LoRA layers, <code>A</code> and <code>B</code>, together have <code>768 x 4 + 4 x 768 = 6,144</code> parameters. So, for the dense layer, we go from <code>589,824</code> trainable parameters to <code>6,144</code> trainable parameters!</p> <h4 id="why-does-lora-reduce-memory-footprint">Why does LoRA reduce memory footprint?</h4> <p>Even though the total number of parameters increases (since we are adding LoRA layers), the memory footprint shrinks because the number of trainable parameters drops. Let's dive deeper into this.</p> <p>The memory usage of a model can be split into four parts:</p> <ul> <li>Model memory: This is the memory required to store the model weights. This will be slightly higher for the LoRA model than for the original GPT-2.</li> <li>Forward pass memory: This mostly depends on batch size, sequence length, etc. We keep this constant for both models for a fair comparison.</li> <li>Backward pass memory: This is the memory required to store the gradients. Note that the gradients are computed only for the trainable parameters.</li> <li>Optimizer memory: This is the memory required to store the optimizer state. For example, the Adam optimizer stores the "1st moment vectors" and "2nd moment vectors" for the trainable parameters.</li> </ul> <p>Since LoRA hugely reduces the number of trainable parameters, the optimizer memory and the memory required to store the gradients are much smaller for the LoRA model than for full GPT-2 fine-tuning. 
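</p> <p>To put rough numbers on the gradient and optimizer memory for the single dense layer from the example above, here is a back-of-the-envelope sketch. It assumes float32 values (4 bytes each), one gradient per trainable parameter, and two Adam-style moment vectors per trainable parameter. Activations and framework overhead are ignored, and the helper name <code>training_state_bytes</code> is hypothetical.</p>
<div class="codehilite"><pre><span></span><code>BYTES_PER_FLOAT32 = 4  # assumption: no mixed precision


def training_state_bytes(num_trainable_params):
    # One gradient value plus two optimizer moment values per parameter.
    values_per_param = 1 + 2
    return num_trainable_params * values_per_param * BYTES_PER_FLOAT32


n, rank = 768, 4
full_params = n * n  # trainable parameters when W0 itself is updated
lora_params = n * rank + rank * n  # trainable parameters in A and B

full_bytes = training_state_bytes(full_params)
lora_bytes = training_state_bytes(lora_params)

print(f"Full fine-tuning: {full_bytes / 1024:.0f} KiB")
print(f"LoRA: {lora_bytes / 1024:.0f} KiB")
print(f"Reduction factor: {full_bytes // lora_bytes}x")
</code></pre></div>
<p>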
This is where most of the memory savings happen.</p> <h4 id="why-is-lora-so-popular">Why is LoRA so popular?</h4> <ul> <li>Reduces GPU memory usage;</li> <li>Faster training; and</li> <li>No additional inference latency.</li> </ul> <h3 id="create-lora-layer">Create LoRA layer</h3> <p>According to the technical description above, let's create a LoRA layer. In a transformer model, the LoRA layer is created and injected for the query and value projection matrices. In <a href="/api/layers/attention_layers/multi_head_attention#multiheadattention-class"><code>keras.layers.MultiHeadAttention</code></a>, the query/value projection layers are <a href="/api/layers/core_layers/einsum_dense#einsumdense-class"><code>keras.layers.EinsumDense</code></a> layers.</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">math</span> <span class="k">class</span> <span class="nc">LoraLayer</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">original_layer</span><span class="p">,</span> <span class="n">rank</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span> <span class="p">):</span> <span class="c1"># We want to keep the name of this layer the same as the original</span> <span class="c1"># dense layer.</span> <span class="n">original_layer_config</span> <span class="o">=</span> <span class="n">original_layer</span><span class="o">.</span><span 
class="n">get_config</span><span class="p">()</span> <span class="n">name</span> <span class="o">=</span> <span class="n">original_layer_config</span><span class="p">[</span><span class="s2">"name"</span><span class="p">]</span> <span class="n">kwargs</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">"name"</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="n">trainable</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">rank</span> <span class="o">=</span> <span class="n">rank</span> <span class="bp">self</span><span class="o">.</span><span class="n">alpha</span> <span class="o">=</span> <span class="n">alpha</span> <span class="bp">self</span><span class="o">.</span><span class="n">_scale</span> <span class="o">=</span> <span class="n">alpha</span> <span class="o">/</span> <span class="n">rank</span> <span class="bp">self</span><span class="o">.</span><span class="n">_num_heads</span> <span class="o">=</span> <span class="n">original_layer_config</span><span class="p">[</span><span class="s2">"output_shape"</span><span class="p">][</span><span class="o">-</span><span class="mi">2</span><span class="p">]</span> <span class="bp">self</span><span class="o">.</span><span class="n">_hidden_dim</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_num_heads</span> <span class="o">*</span> <span class="n">original_layer_config</span><span class="p">[</span><span class="s2">"output_shape"</span><span 
class="p">][</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="c1"># Layers.</span> <span class="c1"># Original dense layer.</span> <span class="bp">self</span><span class="o">.</span><span class="n">original_layer</span> <span class="o">=</span> <span class="n">original_layer</span> <span class="c1"># No matter whether we are training the model or are in inference mode,</span> <span class="c1"># this layer should be frozen.</span> <span class="bp">self</span><span class="o">.</span><span class="n">original_layer</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">False</span> <span class="c1"># LoRA dense layers.</span> <span class="bp">self</span><span class="o">.</span><span class="n">A</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span> <span class="n">units</span><span class="o">=</span><span class="n">rank</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="c1"># Note: the original paper mentions that normal distribution was</span> <span class="c1"># used for initialization. 
However, the official LoRA implementation</span> <span class="c1"># uses "Kaiming/He Initialization".</span> <span class="n">kernel_initializer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">VarianceScaling</span><span class="p">(</span> <span class="n">scale</span><span class="o">=</span><span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mi">5</span><span class="p">),</span> <span class="n">mode</span><span class="o">=</span><span class="s2">"fan_in"</span><span class="p">,</span> <span class="n">distribution</span><span class="o">=</span><span class="s2">"uniform"</span> <span class="p">),</span> <span class="n">trainable</span><span class="o">=</span><span class="n">trainable</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">"lora_A"</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># B has the same `equation` and `output_shape` as the original layer.</span> <span class="c1"># `equation = abc,cde->abde`, where `a`: batch size, `b`: sequence</span> <span class="c1"># length, `c`: `hidden_dim`, `d`: `num_heads`,</span> <span class="c1"># `e`: `hidden_dim//num_heads`. 
The only difference is that in layer `B`,</span> <span class="c1"># `c` represents `rank`.</span> <span class="bp">self</span><span class="o">.</span><span class="n">B</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">EinsumDense</span><span class="p">(</span> <span class="n">equation</span><span class="o">=</span><span class="n">original_layer_config</span><span class="p">[</span><span class="s2">"equation"</span><span class="p">],</span> <span class="n">output_shape</span><span class="o">=</span><span class="n">original_layer_config</span><span class="p">[</span><span class="s2">"output_shape"</span><span class="p">],</span> <span class="n">kernel_initializer</span><span class="o">=</span><span class="s2">"zeros"</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="n">trainable</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">"lora_B"</span><span class="p">,</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">original_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">original_layer</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainable</span><span class="p">:</span> <span class="c1"># If we are fine-tuning the model, we will add LoRA layers' output</span> <span class="c1"># to the original layer's output.</span> <span class="n">lora_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">B</span><span class="p">(</span><span 
class="bp">self</span><span class="o">.</span><span class="n">A</span><span class="p">(</span><span class="n">inputs</span><span class="p">))</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">_scale</span> <span class="k">return</span> <span class="n">original_output</span> <span class="o">+</span> <span class="n">lora_output</span> <span class="c1"># If we are in inference mode, we "merge" the LoRA layers' weights into</span> <span class="c1"># the original layer's weights - more on this in the text generation</span> <span class="c1"># section!</span> <span class="k">return</span> <span class="n">original_output</span> </code></pre></div> <h3 id="inject-lora-layer-into-the-model">Inject LoRA layer into the model</h3> <p>We will now hack the original GPT-2 model and inject LoRA layers into it. Let's do a couple of things before doing that:</p> <ul> <li>Delete previous model;</li> <li>Reset "peak" GPU memory usage using <a href="https://www.tensorflow.org/api_docs/python/tf/config/experimental/reset_memory_stats"><code>tf.config.experimental.reset_memory_stats</code></a>;</li> <li>Load a new GPT-2 model.</li> </ul> <div class="codehilite"><pre><span></span><code><span class="k">del</span> <span class="n">gpt2_lm</span> <span class="k">del</span> <span class="n">optimizer</span> <span class="k">del</span> <span class="n">loss</span> <span class="c1"># This resets "peak" memory usage to "current" memory usage.</span> <span class="n">tf</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">experimental</span><span class="o">.</span><span class="n">reset_memory_stats</span><span class="p">(</span><span class="s2">"GPU:0"</span><span class="p">)</span> <span class="c1"># Load the original model.</span> <span class="n">preprocessor</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span 
class="n">GPT2CausalLMPreprocessor</span><span class="o">.</span><span class="n">from_preset</span><span class="p">(</span> <span class="s2">"gpt2_base_en"</span><span class="p">,</span> <span class="n">sequence_length</span><span class="o">=</span><span class="n">MAX_SEQUENCE_LENGTH</span><span class="p">,</span> <span class="p">)</span> <span class="n">lora_model</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">GPT2CausalLM</span><span class="o">.</span><span class="n">from_preset</span><span class="p">(</span> <span class="s2">"gpt2_base_en"</span><span class="p">,</span> <span class="n">preprocessor</span><span class="o">=</span><span class="n">preprocessor</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <p>We will now replace the original query/value projection layers with our new LoRA layers.</p> <div class="codehilite"><pre><span></span><code><span class="k">for</span> <span class="n">layer_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">lora_model</span><span class="o">.</span><span class="n">backbone</span><span class="o">.</span><span class="n">num_layers</span><span class="p">):</span> <span class="c1"># Get the decoder layer and its self-attention sublayer.</span> <span class="n">decoder_layer</span> <span class="o">=</span> <span class="n">lora_model</span><span class="o">.</span><span class="n">backbone</span><span class="o">.</span><span class="n">get_layer</span><span class="p">(</span><span class="sa">f</span><span class="s2">"transformer_layer_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="n">self_attention_layer</span> <span class="o">=</span> <span class="n">decoder_layer</span><span class="o">.</span><span class="n">_self_attention_layer</span> <span class="c1"># Allow mutation to Keras 
layer state.</span> <span class="n">self_attention_layer</span><span class="o">.</span><span class="n">_tracker</span><span class="o">.</span><span class="n">locked</span> <span class="o">=</span> <span class="kc">False</span> <span class="c1"># Change query dense layer.</span> <span class="n">self_attention_layer</span><span class="o">.</span><span class="n">_query_dense</span> <span class="o">=</span> <span class="n">LoraLayer</span><span class="p">(</span> <span class="n">self_attention_layer</span><span class="o">.</span><span class="n">_query_dense</span><span class="p">,</span> <span class="n">rank</span><span class="o">=</span><span class="n">RANK</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="n">ALPHA</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Change value dense layer.</span> <span class="n">self_attention_layer</span><span class="o">.</span><span class="n">_value_dense</span> <span class="o">=</span> <span class="n">LoraLayer</span><span class="p">(</span> <span class="n">self_attention_layer</span><span class="o">.</span><span class="n">_value_dense</span><span class="p">,</span> <span class="n">rank</span><span class="o">=</span><span class="n">RANK</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="n">ALPHA</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <p>Let's now do a forward pass to make sure we still have a valid chain of computation.</p> <div class="codehilite"><pre><span></span><code><span class="n">lora_model</span><span class="p">(</span><span class="n">preprocessor</span><span class="p">([</span><span class="s2">"LoRA is very useful for quick LLM finetuning"</span><span 
class="p">])[</span><span class="mi">0</span><span class="p">])</span>
<span class="k">pass</span>
</code></pre></div> <p>Freeze the entire LLM; only the LoRA layers should be trainable.</p> <div class="codehilite"><pre><span></span><code><span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">lora_model</span><span class="o">.</span><span class="n">_flatten_layers</span><span class="p">():</span>
    <span class="n">lst_of_sublayers</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">_flatten_layers</span><span class="p">())</span>

    <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">lst_of_sublayers</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>  <span class="c1"># "leaves of the model"</span>
        <span class="k">if</span> <span class="n">layer</span><span class="o">.</span><span class="n">name</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">"lora_A"</span><span class="p">,</span> <span class="s2">"lora_B"</span><span class="p">]:</span>
            <span class="n">layer</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">True</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="n">layer</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">False</span>
</code></pre></div> <p>Print the model's summary and check that the numbers of trainable, non-trainable and total parameters are correct.</p> <p>In a previous section, we calculated that the LoRA layers wrapping a single dense layer have 6,144 parameters (for a rank of 4). The total trainable parameters in the model should be <code>num_layers * (query, value) * 6,144 = 12 * 2 * 6,144 = 147,456</code>. 
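</p> <p>The same arithmetic can be written as a few lines of Python. This is only a sanity-check sketch: the values are copied from the text and the GPT-2 summary above (hidden size 768, rank 4, 12 transformer layers, query and value projections), and the variable names are chosen here for illustration.</p>
<div class="codehilite"><pre><span></span><code>hidden_dim = 768
rank = 4
num_layers = 12
adapted_projections_per_layer = 2  # query and value

# Parameters in A (hidden_dim x rank) and B (rank x hidden_dim).
params_per_lora_layer = hidden_dim * rank + rank * hidden_dim

trainable_params = (
    num_layers * adapted_projections_per_layer * params_per_lora_layer
)

# Parameter count of the original GPT-2 model, taken from its summary.
frozen_params = 124_439_808
total_params = frozen_params + trainable_params

print(f"Trainable params: {trainable_params:,}")
print(f"Total params: {total_params:,}")
</code></pre></div>
<p>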
The number of non-trainable parameters should be the same as the total number of parameters in the original GPT-2 model, which is <code>124,439,808</code>.</p> <div class="codehilite"><pre><span></span><code><span class="n">lora_model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> </code></pre></div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Preprocessor: "gpt2_causal_lm_preprocessor_1"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Tokenizer (type) </span>┃<span style="font-weight: bold"> Vocab # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ gpt2_tokenizer_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">GPT2Tokenizer</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">50,257</span> │ └────────────────────────────────────────────────────┴─────────────────────────────────────────────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "gpt2_causal_lm_1"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃<span style="font-weight: bold"> Connected to </span>┃ 
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ padding_mask (<span style="color: #0087ff; text-decoration-color: #0087ff">InputLayer</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ - │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ token_ids (<span style="color: #0087ff; text-decoration-color: #0087ff">InputLayer</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ - │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ gpt2_backbone_1 │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">768</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">124,587,264</span> │ padding_mask[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>], │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">GPT2Backbone</span>) │ │ │ token_ids[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>] │ ├───────────────────────────────┼───────────────────────────┼─────────────┼────────────────────────────────┤ │ token_embedding │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span 
style="color: #00af00; text-decoration-color: #00af00">50257</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">38,597,376</span> │ gpt2_backbone_1[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>] │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">ReversibleEmbedding</span>) │ │ │ │ └───────────────────────────────┴───────────────────────────┴─────────────┴────────────────────────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">124,587,264</span> (475.26 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">147,456</span> (576.00 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">124,439,808</span> (474.70 MB) </pre> <h3 id="finetune-lora-gpt2">Fine-tune LoRA GPT-2</h3> <p>Now that we have hacked and verified the LoRA GPT-2 model, let's train it!</p> <div class="codehilite"><pre><span></span><code><span class="n">gpu_memory_callback</span> <span class="o">=</span> <span class="n">GPUMemoryCallback</span><span class="p">(</span> <span class="n">target_batches</span><span class="o">=</span><span class="p">[</span><span class="mi">5</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">25</span><span class="p">,</span> <span class="mi">50</span><span 
class="p">,</span> <span class="mi">100</span><span class="p">,</span> <span class="mi">150</span><span class="p">,</span> <span class="mi">200</span><span class="p">,</span> <span class="mi">300</span><span class="p">,</span> <span class="mi">400</span><span class="p">,</span> <span class="mi">500</span><span class="p">],</span> <span class="n">print_stats</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">get_optimizer_and_loss</span><span class="p">()</span> <span class="n">lora_model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="n">loss</span><span class="p">,</span> <span class="n">weighted_metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"accuracy"</span><span class="p">],</span> <span class="p">)</span> <span class="n">lora_model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">train_ds</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">EPOCHS</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">gpu_memory_callback</span><span class="p">],</span> <span class="p">)</span> <span class="n">lora_model_memory_usage</span> <span class="o">=</span> <span class="n">gpu_memory_callback</span><span class="o">.</span><span class="n">memory_usage</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 2/500 [37m━━━━━━━━━━━━━━━━━━━━ 41s 84ms/step - accuracy: 0.2828 - loss: 3.7188 W0000 00:00:1701128576.353742 38699 
graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update 500/500 ━━━━━━━━━━━━━━━━━━━━ 80s 81ms/step - accuracy: 0.2930 - loss: 3.6158 </code></pre></div> </div> <p>And we are done fine-tuning the model! Before we generate text, let's compare the training time and memory usage of the two models. On a 16 GB Tesla T4 (Colab), fine-tuning GPT-2 takes 7 minutes, while fine-tuning LoRA GPT-2 takes 5 minutes, a decrease of roughly 30%. The memory usage of LoRA GPT-2 is roughly 35% lower than that of GPT-2.</p> <div class="codehilite"><pre><span></span><code><span class="n">plt</span><span class="o">.</span><span class="n">bar</span><span class="p">(</span> <span class="p">[</span><span class="s2">"GPT-2"</span><span class="p">,</span> <span class="s2">"LoRA GPT-2"</span><span class="p">],</span> <span class="p">[</span><span class="nb">max</span><span class="p">(</span><span class="n">gpt2_lm_memory_usage</span><span class="p">),</span> <span class="nb">max</span><span class="p">(</span><span class="n">lora_model_memory_usage</span><span class="p">)],</span> <span class="n">color</span><span class="o">=</span><span class="p">[</span><span class="s2">"red"</span><span class="p">,</span> <span class="s2">"blue"</span><span class="p">],</span> <span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s2">"Time"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s2">"GPU Memory Usage (in GB)"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">"GPU Memory Usage Comparison"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span 
class="p">()</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>WARNING:matplotlib.legend:No artists with labels found to put in legend. Note that artists whose label start with an underscore are ignored when legend() is called with no argument. </code></pre></div> </div> <p><img alt="png" src="/img/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/parameter_efficient_finetuning_of_gpt2_with_lora_43_1.png" /></p> <h3 id="merge-weights-and-generate-text">Merge weights and generate text!</h3> <p>One of the biggest advantages of LoRA over other adapter methods is that it does not incur any additional inference latency. Let's understand why.</p> <p>Recall our LoRA equation: <code>output = W0x + b0 + BAx</code>. We can rewrite this as: <code>output = (W0 + BA)x + b0 = Wx + b0</code>, where <code>W = W0 + BA</code>. This means that if we merge the adapter weights into the weights of the original model, inference performs essentially the same computation as the original model, only with updated weights!</p> <div class="codehilite"><pre><span></span><code><span class="k">for</span> <span class="n">layer_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">lora_model</span><span class="o">.</span><span class="n">backbone</span><span class="o">.</span><span class="n">num_layers</span><span class="p">):</span> <span class="n">self_attention_layer</span> <span class="o">=</span> <span class="n">lora_model</span><span class="o">.</span><span class="n">backbone</span><span class="o">.</span><span class="n">get_layer</span><span class="p">(</span> <span class="sa">f</span><span class="s2">"transformer_layer_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s2">"</span> <span class="p">)</span><span class="o">.</span><span class="n">_self_attention_layer</span> <span class="c1"># Merge query dense layer.</span> <span 
class="n">query_lora_layer</span> <span class="o">=</span> <span class="n">self_attention_layer</span><span class="o">.</span><span class="n">_query_dense</span> <span class="n">A_weights</span> <span class="o">=</span> <span class="n">query_lora_layer</span><span class="o">.</span><span class="n">A</span><span class="o">.</span><span class="n">kernel</span> <span class="c1"># (768, RANK) (a, b)</span> <span class="n">B_weights</span> <span class="o">=</span> <span class="n">query_lora_layer</span><span class="o">.</span><span class="n">B</span><span class="o">.</span><span class="n">kernel</span> <span class="c1"># (RANK, 12, 64) (b, c, d)</span> <span class="n">increment_weights</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">"ab,bcd->acd"</span><span class="p">,</span> <span class="n">A_weights</span><span class="p">,</span> <span class="n">B_weights</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">ALPHA</span> <span class="o">/</span> <span class="n">RANK</span><span class="p">)</span> <span class="n">query_lora_layer</span><span class="o">.</span><span class="n">original_layer</span><span class="o">.</span><span class="n">kernel</span><span class="o">.</span><span class="n">assign_add</span><span class="p">(</span><span class="n">increment_weights</span><span class="p">)</span> <span class="c1"># Merge value dense layer.</span> <span class="n">value_lora_layer</span> <span class="o">=</span> <span class="n">self_attention_layer</span><span class="o">.</span><span class="n">_value_dense</span> <span class="n">A_weights</span> <span class="o">=</span> <span class="n">value_lora_layer</span><span class="o">.</span><span class="n">A</span><span class="o">.</span><span class="n">kernel</span> <span class="c1"># (768, RANK) (a, b)</span> <span class="n">B_weights</span> <span class="o">=</span> <span 
class="n">value_lora_layer</span><span class="o">.</span><span class="n">B</span><span class="o">.</span><span class="n">kernel</span> <span class="c1"># (RANK, 12, 64) (b, c, d)</span> <span class="n">increment_weights</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">"ab,bcd->acd"</span><span class="p">,</span> <span class="n">A_weights</span><span class="p">,</span> <span class="n">B_weights</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">ALPHA</span> <span class="o">/</span> <span class="n">RANK</span><span class="p">)</span> <span class="n">value_lora_layer</span><span class="o">.</span><span class="n">original_layer</span><span class="o">.</span><span class="n">kernel</span><span class="o">.</span><span class="n">assign_add</span><span class="p">(</span><span class="n">increment_weights</span><span class="p">)</span> <span class="c1"># Put back in place the original layers with updated weights</span> <span class="n">self_attention_layer</span><span class="o">.</span><span class="n">_query_dense</span> <span class="o">=</span> <span class="n">query_lora_layer</span><span class="o">.</span><span class="n">original_layer</span> <span class="n">self_attention_layer</span><span class="o">.</span><span class="n">_value_dense</span> <span class="o">=</span> <span class="n">value_lora_layer</span><span class="o">.</span><span class="n">original_layer</span> </code></pre></div> <p>We are now all set to generate text with our LoRA model :).</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Freezing weights not necessary during generation since no weights are updated.</span> <span class="n">generate_text</span><span class="p">(</span><span class="n">lora_model</span><span class="p">,</span> <span class="s2">"I like basketball"</span><span class="p">,</span> <span class="n">max_length</span><span 
class="o">=</span><span class="n">MAX_GENERATION_LENGTH</span><span class="p">)</span> <span class="n">generate_text</span><span class="p">(</span> <span class="n">lora_model</span><span class="p">,</span> <span class="s2">"That Italian restaurant is"</span><span class="p">,</span> <span class="n">max_length</span><span class="o">=</span><span class="n">MAX_GENERATION_LENGTH</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Output: I like basketball. i've played this game for about a week and i'm pretty tired. today, i'm playing with my friend, who is a really good player. i'm a little older than the average player and i'm a bit too young. Total Time Elapsed: 6.81s </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Output: That Italian restaurant is in the city center and is located on a street that was recently renovated for the summer. </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>i was in a group of friends and had a great time. 
</code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Total Time Elapsed: 0.32s </code></pre></div> </div> <p>And we're all done!</p> </div> </div> </div> </div> </body> </html>