# Parameter-efficient fine-tuning of Gemma with LoRA and QLoRA

**Authors:** [Hongyu Chiu](https://github.com/james77777778), [Abheesht Sharma](https://github.com/abheesht17/), [Matthew Watson](https://github.com/mattdangerw/)<br>
**Date created:** 2024/08/06<br>
**Last modified:** 2024/08/06<br>
**Description:** Use KerasHub to fine-tune a Gemma LLM with LoRA and QLoRA.

ⓘ This example uses Keras 3

[**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_recipes/ipynb/parameter_efficient_finetuning_of_gemma_with_lora_and_qlora.ipynb) •
[**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_recipes/parameter_efficient_finetuning_of_gemma_with_lora_and_qlora.py)
---

## Introduction

Large Language Models (LLMs) have been shown to be effective at a variety of NLP tasks. An LLM is first pre-trained on a large corpus of text in a self-supervised fashion. Pre-training helps LLMs learn general-purpose knowledge, such as statistical relationships between words. An LLM can then be fine-tuned on a downstream task of interest (such as sentiment analysis).

However, LLMs are extremely large, and we don't need to train all of their parameters during fine-tuning, especially because the datasets used for fine-tuning are relatively small. In other words, LLMs are over-parametrized for fine-tuning. This is where [Low-Rank Adaptation (LoRA)](https://arxiv.org/abs/2106.09685) comes in; it significantly reduces the number of trainable parameters, which decreases training time and GPU memory usage while maintaining the quality of the outputs.

Furthermore, [Quantized Low-Rank Adaptation (QLoRA)](https://arxiv.org/abs/2305.14314) extends LoRA to enhance efficiency through quantization techniques, without degrading performance.

In this example, we will fine-tune KerasHub's [Gemma model](https://keras.io/api/keras_hub/models/gemma/) on the next-token prediction task using LoRA and QLoRA.

Note that this example runs on all backends supported by Keras. TensorFlow is only used for data preprocessing.

---

## Setup

Before we start implementing the pipeline, let's install and import all the libraries we need. We'll be using the KerasHub library.

Next, let's set the precision to bfloat16. This will help reduce memory usage and training time.

Also, ensure that `KAGGLE_USERNAME` and `KAGGLE_KEY` have been correctly configured to access the Gemma model.
```python
# We might need the latest code from Keras and KerasHub
!pip install -q git+https://github.com/keras-team/keras.git git+https://github.com/keras-team/keras-hub.git
```

```python
import gc
import os

os.environ["KERAS_BACKEND"] = "jax"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # Suppress verbose logging from TF
# os.environ["KAGGLE_USERNAME"] = "..."
# os.environ["KAGGLE_KEY"] = "..."

import keras
import keras_hub
import tensorflow as tf
import tensorflow_datasets as tfds

keras.config.set_dtype_policy("bfloat16")
```

---

## Dataset

We will use the MTNT (Machine Translation of Noisy Text) dataset, which is available from TensorFlow Datasets. In this example, we will use the French-to-English portion of the dataset.

```python
train_ds = tfds.load("mtnt/fr-en", split="train")
```
We can print some samples. Each sample in the dataset contains two entries:

- `src`: the original French sentence.
- `dst`: the corresponding English translation.

```python
examples = train_ds.take(3)
examples = examples.as_numpy_iterator()

for idx, example in enumerate(examples):
    print(f"Example {idx}:")
    for key, val in example.items():
        print(f"{key}: {val}")
    print()
```

```
Example 0:
dst: b'Yep, serious...'
src: b"Le journal l'est peut-\xc3\xaatre, mais m\xc3\xaame moi qui suit de droite je les trouve limite de temps en temps..."
```

```
Example 1:
dst: b'Finally, I explained to you in what context this copy-pasting is relevant: when we are told padamalgame etc.'
src: b"Enfin je t'ai expliqu\xc3\xa9 dans quel cadre ce copypasta est pertinent : quand on nous dit padamalgame etc."
```

```
Example 2:
dst: b'Gift of Ubiquity: Fran\xc3\xa7ois Baroin is now advisor to the Barclays Bank, mayor, president of the agglomeration, professor at HEC Paris, president of the Association of Mayors of France and Advocate Counselor, it must take him half a day each month.'
src: b"Don d'Ubiquit\xc3\xa9 : Fran\xc3\xa7ois Baroin est d\xc3\xa9sormais conseiller \xc3\xa0 la Banque Barclays, maire, pr\xc3\xa9sident d'agglom\xc3\xa9ration, professeur \xc3\xa0 HEC Paris, pr\xc3\xa9sident de l'association des maires de France et avocat Conseiller, \xc3\xa7a doit lui prendre une demi journ\xc3\xa9e par mois."
```

Since we will fine-tune our model to perform a French-to-English translation task, we should format the inputs for instruction tuning.
For example, we could format the translation task in this example like:

```
<start_of_turn>user
Translate French into English:
{src}<end_of_turn>
<start_of_turn>model
{dst}<end_of_turn>
```

The special tokens such as `<start_of_turn>user`, `<start_of_turn>model` and `<end_of_turn>` are used for Gemma models. You can learn more from https://ai.google.dev/gemma/docs/formatting

```python
train_ds = train_ds.map(
    lambda x: tf.strings.join(
        [
            "<start_of_turn>user\n",
            "Translate French into English:\n",
            x["src"],
            "<end_of_turn>\n",
            "<start_of_turn>model\n",
            "Translation:\n",
            x["dst"],
            "<end_of_turn>",
        ]
    )
)
examples = train_ds.take(3)
examples = examples.as_numpy_iterator()

for idx, example in enumerate(examples):
    print(f"Example {idx}:")
    print(example)
    print()
```

```
Example 0:
b"<start_of_turn>user\nTranslate French into English:\nLe journal l'est peut-\xc3\xaatre, mais m\xc3\xaame moi qui suit de droite je les trouve limite de temps en temps...<end_of_turn>\n<start_of_turn>model\nTranslation:\nYep, serious...<end_of_turn>"
```

```
Example 1:
b"<start_of_turn>user\nTranslate French into English:\nEnfin je t'ai expliqu\xc3\xa9 dans quel cadre ce copypasta est pertinent : quand on nous dit padamalgame etc.<end_of_turn>\n<start_of_turn>model\nTranslation:\nFinally, I explained to you in what context this copy-pasting is relevant: when we are told padamalgame etc.<end_of_turn>"
```
```
Example 2:
b"<start_of_turn>user\nTranslate French into English:\nDon d'Ubiquit\xc3\xa9 : Fran\xc3\xa7ois Baroin est d\xc3\xa9sormais conseiller \xc3\xa0 la Banque Barclays, maire, pr\xc3\xa9sident d'agglom\xc3\xa9ration, professeur \xc3\xa0 HEC Paris, pr\xc3\xa9sident de l'association des maires de France et avocat Conseiller, \xc3\xa7a doit lui prendre une demi journ\xc3\xa9e par mois.<end_of_turn>\n<start_of_turn>model\nTranslation:\nGift of Ubiquity: Fran\xc3\xa7ois Baroin is now advisor to the Barclays Bank, mayor, president of the agglomeration, professor at HEC Paris, president of the Association of Mayors of France and Advocate Counselor, it must take him half a day each month.<end_of_turn>"
```

We will take a subset of the dataset for the purpose of this example.

```python
train_ds = train_ds.batch(1).take(100)
```

---

## Model

KerasHub provides implementations of many popular model architectures. In this example, we will use `GemmaCausalLM`, an end-to-end Gemma model for causal language modeling. A causal language model predicts the next token based on previous tokens.

Note that `sequence_length` is set to `256` to speed up the fitting.

```python
preprocessor = keras_hub.models.GemmaCausalLMPreprocessor.from_preset(
    "gemma_1.1_instruct_2b_en", sequence_length=256
)
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset(
    "gemma_1.1_instruct_2b_en", preprocessor=preprocessor
)
gemma_lm.summary()
```
```
Preprocessor: "gemma_causal_lm_preprocessor"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┓
┃ Tokenizer (type)                 ┃   Vocab # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━┩
│ gemma_tokenizer (GemmaTokenizer) │   256,000 │
└──────────────────────────────────┴───────────┘

Model: "gemma_causal_lm"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                   ┃ Output Shape         ┃       Param # ┃ Connected to         ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ padding_mask (InputLayer)      │ (None, None)         │             0 │ -                    │
├────────────────────────────────┼──────────────────────┼───────────────┼──────────────────────┤
│ token_ids (InputLayer)         │ (None, None)         │             0 │ -                    │
├────────────────────────────────┼──────────────────────┼───────────────┼──────────────────────┤
│ gemma_backbone                 │ (None, None, 2048)   │ 2,506,172,416 │ padding_mask[0][0],  │
│ (GemmaBackbone)                │                      │               │ token_ids[0][0]      │
├────────────────────────────────┼──────────────────────┼───────────────┼──────────────────────┤
│ token_embedding                │ (None, None, 256000) │   524,288,000 │ gemma_backbone[0][0] │
│ (ReversibleEmbedding)          │                      │               │                      │
└────────────────────────────────┴──────────────────────┴───────────────┴──────────────────────┘

 Total params: 2,506,172,416 (4.67 GB)
 Trainable params: 2,506,172,416 (4.67 GB)
 Non-trainable params: 0 (0.00 B)
```
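As a quick sanity check on the summary above, the 4.67 GB figure is consistent with the bfloat16 dtype policy we set earlier (2 bytes per parameter, reported in 1024-based units). This is just arithmetic, not a library call:

```python
params = 2_506_172_416
bytes_per_param = 2  # bfloat16 weights
print(f"{params * bytes_per_param / 1024**3:.2f} GB")  # -> 4.67
```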
---

## LoRA Fine-tuning

### What exactly is LoRA?

Low-rank adaptation (LoRA) is a parameter-efficient fine-tuning technique for LLMs. It freezes the weights of the LLM, and injects trainable rank-decomposition matrices. Let's understand this more clearly.

Assume we have an `n x n` pre-trained dense layer (or weight matrix), `W0`. We initialize two dense layers, `A` and `B`, of shapes `n x rank` and `rank x n`, respectively. `rank` is much smaller than `n`. In the paper, values between 1 and 4 are shown to work well.

### LoRA equation

The original equation is `output = W0x + b0`, where `x` is the input, and `W0` and `b0` are the weight matrix and bias term of the original dense layer (frozen). The LoRA equation is `output = W0x + b0 + BAx`, where `A` and `B` are the rank-decomposition matrices.

LoRA is based on the idea that updates to the weights of the pre-trained language model have a low "intrinsic rank", since pre-trained language models are over-parametrized. Predictive performance of full fine-tuning can be replicated even by constraining `W0`'s updates to low-rank decomposition matrices.

### Number of trainable parameters

Let's do some quick math. Suppose `n` is 768, and `rank` is 4. `W0` has `768 x 768 = 589,824` parameters, whereas the LoRA layers, `A` and `B`, together have `768 x 4 + 4 x 768 = 6,144` parameters. So, for the dense layer, we go from `589,824` trainable parameters to `6,144` trainable parameters!
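To make the shapes and the parameter count concrete, here is a minimal NumPy sketch of the LoRA update. It uses the row-vector convention `x @ W`, so `A` has shape `(n, rank)` and `B` has shape `(rank, n)`, matching the shapes and the `n = 768`, `rank = 4` arithmetic above. This is only an illustration, not the KerasHub implementation.

```python
import numpy as np

n, rank = 768, 4
rng = np.random.default_rng(0)

# Frozen pre-trained weights and bias.
W0 = rng.normal(size=(n, n))
b0 = rng.normal(size=(n,))

# Trainable rank-decomposition matrices. In LoRA, B starts at zero so that
# training begins from the original model's behavior.
A = rng.normal(size=(n, rank))
B = np.zeros((rank, n))

x = rng.normal(size=(1, n))  # a single input row vector

# Original layer vs. LoRA-augmented layer.
original = x @ W0 + b0
lora = x @ W0 + b0 + (x @ A) @ B

print(np.allclose(original, lora))  # True while B is all zeros
print(W0.size)                      # 589,824 frozen parameters
print(A.size + B.size)              # 6,144 trainable parameters
```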
### Why does LoRA reduce memory footprint?

Even though the total number of parameters increases (since we are adding LoRA layers), the memory footprint reduces, because the number of trainable parameters reduces. Let's dive deeper into this.

The memory usage of a model can be split into four parts:

- Model memory: This is the memory required to store the model weights. It will be slightly higher for LoRA than for the original model.
- Forward pass memory: This mostly depends on batch size, sequence length, etc. We keep this constant for both models for a fair comparison.
- Backward pass memory: This is the memory required to store the gradients. Note that gradients are computed only for the trainable parameters.
- Optimizer memory: This is the memory required to store the optimizer state. For example, the Adam optimizer stores the "1st moment vectors" and "2nd moment vectors" for the trainable parameters.

Since, with LoRA, there is a huge reduction in the number of trainable parameters, the optimizer memory and the memory required to store the gradients are much lower than for the original model. This is where most of the memory savings happen.
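As a rough back-of-the-envelope illustration of that claim, the sketch below estimates the gradient-plus-Adam-state memory for full fine-tuning of the ~2.5B-parameter model versus LoRA's ~1.3M trainable parameters, assuming float32 storage (4 bytes per value). The real numbers depend on the backend, dtype policy, and optimizer, so treat this only as an order-of-magnitude estimate.

```python
BYTES_PER_FLOAT32 = 4

def training_state_bytes(trainable_params: int) -> int:
    # Gradients (1 copy) + Adam first and second moments (2 copies),
    # all sized by the number of *trainable* parameters.
    copies = 1 + 2
    return trainable_params * copies * BYTES_PER_FLOAT32

full_finetune = training_state_bytes(2_506_172_416)  # every weight is trainable
lora_finetune = training_state_bytes(1_363_968)      # only the LoRA weights are trainable

print(f"Full fine-tuning: ~{full_finetune / 1e9:.1f} GB of gradients + optimizer state")
print(f"LoRA fine-tuning: ~{lora_finetune / 1e6:.1f} MB of gradients + optimizer state")
```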
### Why is LoRA so popular?

- Reduces GPU memory usage;
- Faster training; and
- No additional inference latency (see the sketch below).
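The "no additional inference latency" point comes from the fact that, once training is done, the low-rank update can be folded back into the frozen weights: `W0 + BA` is just another `n x n` matrix, so inference uses a single dense layer again. Here is a small NumPy sketch of that merge, illustrative only, using the same row-vector convention as the earlier sketch (so the merged kernel is `W0 + A @ B`):

```python
import numpy as np

n, rank = 768, 4
rng = np.random.default_rng(1)

W0 = rng.normal(size=(n, n))    # frozen pre-trained weights
A = rng.normal(size=(n, rank))  # trained LoRA matrices
B = rng.normal(size=(rank, n))
x = rng.normal(size=(8, n))     # a batch of inputs

# Keeping A and B separate adds an extra matmul at inference time...
with_adapter = x @ W0 + (x @ A) @ B

# ...but merging them into W0 restores the original single-matmul layer.
W_merged = W0 + A @ B
merged = x @ W_merged

print(np.allclose(with_adapter, merged))  # True (up to floating-point error)
```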
When using KerasHub, we can enable LoRA with a one-line API: `enable_lora(rank=4)`.

From `gemma_lm.summary()`, we can see that enabling LoRA reduces the number of trainable parameters significantly (from 2.5 billion to 1.3 million).

```python
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()
```

```
Preprocessor: "gemma_causal_lm_preprocessor"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┓
┃ Tokenizer (type)                 ┃   Vocab # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━┩
│ gemma_tokenizer (GemmaTokenizer) │   256,000 │
└──────────────────────────────────┴───────────┘

Model: "gemma_causal_lm"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                   ┃ Output Shape         ┃       Param # ┃ Connected to         ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ padding_mask (InputLayer)      │ (None, None)         │             0 │ -                    │
├────────────────────────────────┼──────────────────────┼───────────────┼──────────────────────┤
│ token_ids (InputLayer)         │ (None, None)         │             0 │ -                    │
├────────────────────────────────┼──────────────────────┼───────────────┼──────────────────────┤
│ gemma_backbone                 │ (None, None, 2048)   │ 2,507,536,384 │ padding_mask[0][0],  │
│ (GemmaBackbone)                │                      │               │ token_ids[0][0]      │
├────────────────────────────────┼──────────────────────┼───────────────┼──────────────────────┤
│ token_embedding                │ (None, None, 256000) │   524,288,000 │ gemma_backbone[0][0] │
│ (ReversibleEmbedding)          │                      │               │                      │
└────────────────────────────────┴──────────────────────┴───────────────┴──────────────────────┘

 Total params: 2,507,536,384 (4.67 GB)
 Trainable params: 1,363,968 (2.60 MB)
 Non-trainable params: 2,506,172,416 (4.67 GB)
```

Let's fine-tune the LoRA model.

```python
# To save memory, use the SGD optimizer instead of the usual AdamW optimizer.
# For this specific example, SGD is more than enough.
optimizer = keras.optimizers.SGD(learning_rate=1e-4)
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(train_ds, epochs=1)
```

After fine-tuning, responses will follow the instructions provided in the prompt.

```python
template = (
    "<start_of_turn>user\n"
    "Translate French into English:\n"
    "{inputs}"
    "<end_of_turn>\n"
    "<start_of_turn>model\n"
    "Translation:\n"
)
prompt = template.format(inputs="Bonjour, je m'appelle Morgane.")
outputs = gemma_lm.generate(prompt, max_length=256)
print("Translation:\n", outputs.replace(prompt, ""))
```

```
Translation:
 Hello, my name is Morgane.
```
Release memory.

```python
del preprocessor
del gemma_lm
del optimizer
gc.collect()
```

---

## QLoRA Fine-tuning

Quantized Low-Rank Adaptation (QLoRA) extends LoRA to enhance efficiency by quantizing the model weights from high-precision data types, such as float32, to lower-precision data types like int8. This leads to reduced memory usage and faster computation. The saved model weights are also much smaller.
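To give a feel for what dynamic int8 quantization of a weight matrix means, here is a minimal NumPy sketch using a symmetric per-column scale. This only illustrates the idea; the actual behavior of KerasHub's `quantize("int8")` below is more involved.

```python
import numpy as np

rng = np.random.default_rng(2)
W = rng.normal(size=(2048, 256)).astype("float32")  # a float32 weight matrix

# Symmetric quantization: one scale per output column, mapping the largest
# absolute value in each column to 127.
scale = np.abs(W).max(axis=0) / 127.0
W_int8 = np.round(W / scale).astype("int8")  # stored weights: 1 byte per value

# At compute time, the int8 weights are rescaled back to floating point.
W_dequant = W_int8.astype("float32") * scale

print(W.nbytes, W_int8.nbytes)      # 4x smaller storage
print(np.abs(W - W_dequant).max())  # small quantization error
```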
Note that the QLoRA implementation here is a simplified version compared to the original. The differences are:

- The 4-bit NormalFloat format is not used because no backend supports it.
- No double quantization.
- No Paged optimizer.

To enable QLoRA in KerasHub, follow these steps:

1. Instantiate the model.
2. Quantize the weights using dynamic int8 quantization.
3. Enable LoRA.

Steps 2 and 3 are achieved with one-line APIs:

- `quantize("int8")`
- `enable_lora(...)`

```python
preprocessor = keras_hub.models.GemmaCausalLMPreprocessor.from_preset(
    "gemma_1.1_instruct_2b_en", sequence_length=256
)
gemma_lm = keras_hub.models.GemmaCausalLM.from_preset(
    "gemma_1.1_instruct_2b_en", preprocessor=preprocessor
)
gemma_lm.quantize("int8")
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()
```

```
Preprocessor: "gemma_causal_lm_preprocessor_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┓
┃ Tokenizer (type)                 ┃   Vocab # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━┩
│ gemma_tokenizer (GemmaTokenizer) │   256,000 │
└──────────────────────────────────┴───────────┘

Model: "gemma_causal_lm_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                   ┃ Output Shape         ┃       Param # ┃ Connected to         ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ padding_mask (InputLayer)      │ (None, None)         │             0 │ -                    │
├────────────────────────────────┼──────────────────────┼───────────────┼──────────────────────┤
│ token_ids (InputLayer)         │ (None, None)         │             0 │ -                    │
├────────────────────────────────┼──────────────────────┼───────────────┼──────────────────────┤
│ gemma_backbone                 │ (None, None, 2048)   │ 2,508,502,016 │ padding_mask[0][0],  │
│ (GemmaBackbone)                │                      │               │ token_ids[0][0]      │
├────────────────────────────────┼──────────────────────┼───────────────┼──────────────────────┤
│ token_embedding                │ (None, None, 256000) │   524,544,000 │ gemma_backbone[0][0] │
│ (ReversibleEmbedding)          │                      │               │                      │
└────────────────────────────────┴──────────────────────┴───────────────┴──────────────────────┘
```
```
 Total params: 2,508,502,016 (2.34 GB)
 Trainable params: 1,363,968 (2.60 MB)
 Non-trainable params: 2,507,138,048 (2.34 GB)
```

Let's fine-tune the QLoRA model.

If you are using a device with int8 acceleration support, you should see an improvement in the training speed.

```python
optimizer = keras.optimizers.SGD(learning_rate=1e-4)
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(train_ds, epochs=1)
```

You should get a similar output with QLoRA fine-tuning.

```python
prompt = template.format(inputs="Bonjour, je m'appelle Morgane.")
outputs = gemma_lm.generate(prompt, max_length=256)
print("Translation:\n", outputs.replace(prompt, ""))
```
```
Translation:
 Hello, my name is Morgane.
```

And we're all done!

Note that for demonstration purposes, this example fine-tunes the model on a small subset of the dataset for just one epoch and with a low LoRA rank value. To get better responses from the fine-tuned model, you can experiment with:

- Increasing the size of the fine-tuning dataset.
- Training for more steps (epochs).
- Setting a higher LoRA rank.
- Modifying hyperparameter values such as `learning_rate` and `weight_decay`.
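As one concrete, purely hypothetical starting point for those experiments, the snippet below raises the LoRA rank, switches to AdamW with weight decay, and trains for a few epochs. The specific values are not tuned recommendations, just an example of where the knobs live.

```python
# Hypothetical example settings; tune these for your own dataset and hardware.
# When enabling LoRA on a freshly instantiated model, pick a higher rank
# (instead of the rank=4 used above):
gemma_lm.backbone.enable_lora(rank=8)

# AdamW with weight decay instead of plain SGD:
optimizer = keras.optimizers.AdamW(learning_rate=5e-5, weight_decay=0.01)
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Train on more data for more epochs:
gemma_lm.fit(train_ds, epochs=3)
```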