<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/guides/customizing_saving_and_serialization/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Customizing Saving and Serialization"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Customizing Saving and Serialization"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Customizing Saving and Serialization</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 
'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link active" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-sublink" href="/guides/functional_api/">The Functional API</a> <a class="nav-sublink" href="/guides/sequential_model/">The Sequential model</a> <a class="nav-sublink" href="/guides/making_new_layers_and_models_via_subclassing/">Making new layers & models via subclassing</a> <a class="nav-sublink" href="/guides/training_with_built_in_methods/">Training & evaluation with the built-in methods</a> <a class="nav-sublink" href="/guides/custom_train_step_in_jax/">Customizing `fit()` with JAX</a> <a class="nav-sublink" href="/guides/custom_train_step_in_tensorflow/">Customizing `fit()` with TensorFlow</a> <a class="nav-sublink" href="/guides/custom_train_step_in_torch/">Customizing `fit()` with PyTorch</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_jax/">Writing a custom training loop in JAX</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_tensorflow/">Writing a custom training loop in TensorFlow</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_torch/">Writing a custom training loop in PyTorch</a> <a class="nav-sublink" 
href="/guides/serialization_and_saving/">Serialization & saving</a> <a class="nav-sublink active" href="/guides/customizing_saving_and_serialization/">Customizing saving & serialization</a> <a class="nav-sublink" href="/guides/writing_your_own_callbacks/">Writing your own callbacks</a> <a class="nav-sublink" href="/guides/transfer_learning/">Transfer learning & fine-tuning</a> <a class="nav-sublink" href="/guides/distributed_training_with_jax/">Distributed training with JAX</a> <a class="nav-sublink" href="/guides/distributed_training_with_tensorflow/">Distributed training with TensorFlow</a> <a class="nav-sublink" href="/guides/distributed_training_with_torch/">Distributed training with PyTorch</a> <a class="nav-sublink" href="/guides/distribution/">Distributed training with Keras 3</a> <a class="nav-sublink" href="/guides/migrating_to_keras_3/">Migrating Keras 2 code to Keras 3</a> <a class="nav-sublink" href="/guides/keras_tuner/">Hyperparameter Tuning</a> <a class="nav-sublink" href="/guides/keras_cv/">KerasCV</a> <a class="nav-sublink" href="/guides/keras_nlp/">KerasNLP</a> <a class="nav-sublink" href="/guides/keras_hub/">KerasHub</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if 
(e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." 
autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + encodeURIComponent(query); return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/guides/'>Developer guides</a> / Customizing Saving and Serialization </div> <div class='k-content'> <h1 id="customizing-saving-and-serialization">Customizing Saving and Serialization</h1> <p><strong>Author:</strong> Neel Kovelamudi<br> <strong>Date created:</strong> 2023/03/15<br> <strong>Last modified:</strong> 2023/03/15<br> <strong>Description:</strong> A more advanced guide on customizing saving for your layers and models.</p> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a 
href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/customizing_saving_and_serialization.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/guides/customizing_saving_and_serialization.py"><strong>GitHub source</strong></a></p>
<hr />
<h2 id="introduction">Introduction</h2>
<p>This guide covers advanced methods that you can override to customize Keras saving. For most users, the methods outlined in the primary <a href="https://keras.io/guides/serialization_and_saving">Serialize, save, and export guide</a> are sufficient.</p>
<h3 id="apis">APIs</h3>
<p>We will cover the following APIs:</p>
<ul> <li><code>save_assets()</code> and <code>load_assets()</code></li> <li><code>save_own_variables()</code> and <code>load_own_variables()</code></li> <li><code>get_build_config()</code> and <code>build_from_config()</code></li> <li><code>get_compile_config()</code> and <code>compile_from_config()</code></li> </ul>
<p>When restoring a model, these methods are executed in the following order:</p>
<ul> <li><code>build_from_config()</code></li> <li><code>compile_from_config()</code></li> <li><code>load_own_variables()</code></li> <li><code>load_assets()</code></li> </ul>
<hr />
<h2 id="setup">Setup</h2>
<div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">keras</span>
</code></pre></div>
<hr />
<h2 id="state-saving-customization">State saving customization</h2>
<p>These methods determine how the state of your model's layers is saved when calling <code>model.save()</code>. 
You can override them to take full control of the state saving process.</p>
<h3 id="saveownvariables-and-loadownvariables"><code>save_own_variables()</code> and <code>load_own_variables()</code></h3>
<p>These methods save and load the state variables of the layer when <code>model.save()</code> and <code>keras.models.load_model()</code> are called, respectively. By default, the state variables saved and loaded are the weights of the layer (both trainable and non-trainable). Here is the default implementation of <code>save_own_variables()</code>:</p>
<div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">save_own_variables</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">store</span><span class="p">):</span>
    <span class="n">all_vars</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_trainable_weights</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">_non_trainable_weights</span>
    <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">all_vars</span><span class="p">):</span>
        <span class="n">store</span><span class="p">[</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">"</span><span class="p">]</span> <span class="o">=</span> <span class="n">v</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>
</code></pre></div>
<p>The store used by these methods is a dictionary that can be populated with the layer variables. 
Let's take a look at an example customizing this.</p>
<p><strong>Example:</strong></p>
<div class="codehilite"><pre><span></span><code><span class="nd">@keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">register_keras_serializable</span><span class="p">(</span><span class="n">package</span><span class="o">=</span><span class="s2">"my_custom_package"</span><span class="p">)</span>
<span class="k">class</span> <span class="nc">LayerWithCustomVariable</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">units</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">units</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">my_variable</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span>
            <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="n">units</span><span class="p">,)),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"my_variable"</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span>
        <span class="p">)</span>

    <span class="k">def</span> <span class="nf">save_own_variables</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">store</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">save_own_variables</span><span class="p">(</span><span class="n">store</span><span class="p">)</span>
        <span class="c1"># Stores the value of the variable upon saving</span>
        <span class="n">store</span><span class="p">[</span><span class="s2">"variables"</span><span class="p">]</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">my_variable</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>

    <span class="k">def</span> <span class="nf">load_own_variables</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">store</span><span class="p">):</span>
        <span class="c1"># Assigns the value of the variable upon loading</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">my_variable</span><span class="o">.</span><span class="n">assign</span><span class="p">(</span><span class="n">store</span><span class="p">[</span><span class="s2">"variables"</span><span class="p">])</span>
        <span class="c1"># Load the remaining weights</span>
        <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">weights</span><span class="p">):</span>
            <span class="n">v</span><span class="o">.</span><span class="n">assign</span><span class="p">(</span><span class="n">store</span><span class="p">[</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">"</span><span class="p">])</span>
        <span class="c1"># Note: You must specify how all variables (including layer weights)</span>
        <span class="c1"># are loaded in `load_own_variables()`.</span>

    <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span>
        <span class="n">dense_out</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">call</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">dense_out</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">my_variable</span>


<span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span><span class="n">LayerWithCustomVariable</span><span class="p">(</span><span class="mi">1</span><span class="p">)])</span>

<span class="n">ref_input</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">8</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
<span class="n">ref_output</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">8</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
<span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="s2">"adam"</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"mean_squared_error"</span><span class="p">)</span>
<span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">ref_input</span><span class="p">,</span> <span class="n">ref_output</span><span class="p">)</span>

<span class="n">model</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s2">"custom_vars_model.keras"</span><span class="p">)</span>
<span class="n">restored_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">load_model</span><span class="p">(</span><span class="s2">"custom_vars_model.keras"</span><span class="p">)</span>

<span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span>
    <span class="n">model</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">my_variable</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span>
    <span class="n">restored_model</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">my_variable</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span>
<span class="p">)</span>
</code></pre></div>
<div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 101ms/step - loss: 0.2908
</code></pre></div> </div>
<h3 id="saveassets-and-loadassets"><code>save_assets()</code> and <code>load_assets()</code></h3>
<p>These methods can be added to your model class definition to store and load any additional information that your model needs.</p>
<p>For example, NLP domain layers such as TextVectorization layers and IndexLookup layers may need to store their associated vocabulary (or lookup table) in a text file upon saving.</p>
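<p>As a rough, Keras-free sketch of that idea (an illustration only, not the library's implementation), the hypothetical <code>VocabHolder</code> class below mimics the two hooks: it writes a vocabulary one token per line into a directory and reads it back. The class name, the <code>round_trip()</code> helper, and the temporary directory standing in for the path that Keras would supply are all assumptions made for this illustration.</p>

```python
import os
import tempfile


class VocabHolder:
    """Hypothetical stand-in for a layer; this is not a Keras class."""

    def __init__(self, vocab=None):
        self.vocab = list(vocab or [])

    def save_assets(self, inner_path):
        # Write one token per line into the directory we are handed.
        path = os.path.join(inner_path, "vocabulary.txt")
        with open(path, "w", encoding="utf-8") as f:
            f.write("\n".join(self.vocab))

    def load_assets(self, inner_path):
        # Rebuild the token list from the same file at load time.
        path = os.path.join(inner_path, "vocabulary.txt")
        with open(path, "r", encoding="utf-8") as f:
            self.vocab = f.read().splitlines()


def round_trip(tokens):
    # A temporary directory plays the role of the assets folder (assumption).
    with tempfile.TemporaryDirectory() as assets_dir:
        VocabHolder(tokens).save_assets(assets_dir)
        restored = VocabHolder()
        restored.load_assets(assets_dir)
    return restored.vocab
```

<p>Tokens that themselves contain newlines would not survive this one-token-per-line format; a real layer would need an escaping scheme or a different file format.</p>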
<p>Let's look at the basics of this workflow using a simple text file, <code>vocabulary.txt</code>.</p>
<p><strong>Example:</strong></p>
<div class="codehilite"><pre><span></span><code><span class="nd">@keras</span><span class="o">.</span><span class="n">saving</span><span class="o">.</span><span class="n">register_keras_serializable</span><span class="p">(</span><span class="n">package</span><span class="o">=</span><span class="s2">"my_custom_package"</span><span class="p">)</span>
<span class="k">class</span> <span class="nc">LayerWithCustomAssets</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">vocab</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">vocab</span> <span class="o">=</span> <span class="n">vocab</span>

    <span class="k">def</span> <span class="nf">save_assets</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inner_path</span><span class="p">):</span>
        <span class="c1"># Writes the vocab (sentence) to text file at save time.</span>
        <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">inner_path</span><span class="p">,</span> <span class="s2">"vocabulary.txt"</span><span class="p">),</span> <span class="s2">"w"</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
            <span class="n">f</span><span class="o">.</span><span class="n">write</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">vocab</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">load_assets</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inner_path</span><span class="p">):</span>
        <span class="c1"># Reads the vocab (sentence) from text file at load time.</span>
        <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">inner_path</span><span class="p">,</span> <span class="s2">"vocabulary.txt"</span><span class="p">),</span> <span class="s2">"r"</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span>
            <span class="n">text</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">vocab</span> <span class="o">=</span> <span class="n">text</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">"&lt;unk&gt;"</span><span class="p">,</span> <span class="s2">"little"</span><span class="p">)</span>


<span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
    <span class="p">[</span><span class="n">LayerWithCustomAssets</span><span class="p">(</span><span class="n">vocab</span><span class="o">=</span><span class="s2">"Mary had a &lt;unk&gt; lamb."</span><span class="p">,</span> <span class="n">units</span><span class="o">=</span><span class="mi">5</span><span class="p">)]</span>
<span class="p">)</span>

<span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>

<span class="n">model</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s2">"custom_assets_model.keras"</span><span class="p">)</span>
<span class="n">restored_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">load_model</span><span class="p">(</span><span class="s2">"custom_assets_model.keras"</span><span class="p">)</span>

<span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_string_equal</span><span class="p">(</span>
    <span class="n">restored_model</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">vocab</span><span class="p">,</span> <span class="s2">"Mary had a little lamb."</span>
<span class="p">)</span>
</code></pre></div>
<hr />
<h2 id="build-and-compile-saving-customization"><code>build</code> and <code>compile</code> saving customization</h2>
<h3 id="getbuildconfig-and-buildfromconfig"><code>get_build_config()</code> and 
<code>build_from_config()</code></h3>
<p>These methods work together to save the layer's built state and restore it upon loading.</p>
<p>By default, the build config is a dictionary containing only the layer's input shape, but you can override these methods to include further variables and lookup tables that are useful to restore for your built model.</p>
<p><strong>Example:</strong></p>
<div class="codehilite"><pre><span></span><code><span class="nd">@keras</span><span class="o">.</span><span class="n">saving</span><span class="o">.</span><span class="n">register_keras_serializable</span><span class="p">(</span><span class="n">package</span><span class="o">=</span><span class="s2">"my_custom_package"</span><span class="p">)</span>
<span class="k">class</span> <span class="nc">LayerWithCustomBuild</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">units</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">units</span> <span class="o">=</span> <span class="n">units</span>

    <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span>

    <span class="k">def</span> <span class="nf">get_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="nb">dict</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">,</span> <span class="o">**</span><span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">get_config</span><span class="p">())</span>

    <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">,</span> <span class="n">layer_init</span><span class="p">):</span>
        <span class="c1"># Note the overriding of `build()` to add an extra argument.</span>
        <span class="c1"># Therefore, we will need to manually call build with `layer_init` argument</span>
        <span class="c1"># before the first execution of `call()`.</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">_input_shape</span> <span class="o">=</span> <span class="n">input_shape</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span>
            <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">input_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">),</span>
            <span class="n">initializer</span><span class="o">=</span><span class="n">layer_init</span><span class="p">,</span>
            <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
        <span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">b</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span>
            <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">,),</span>
            <span class="n">initializer</span><span class="o">=</span><span class="n">layer_init</span><span class="p">,</span>
            <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
        <span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">layer_init</span> <span class="o">=</span> <span class="n">layer_init</span>

    <span class="k">def</span> <span class="nf">get_build_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="n">build_config</span> <span class="o">=</span> <span class="p">{</span>
            <span class="s2">"layer_init"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_init</span><span class="p">,</span>
            <span class="s2">"input_shape"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">_input_shape</span><span class="p">,</span>
        <span class="p">}</span>  <span class="c1"># Stores our initializer for `build()`</span>
        <span class="k">return</span> <span class="n">build_config</span>

    <span class="k">def</span> <span class="nf">build_from_config</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">config</span><span class="p">):</span>
        <span class="c1"># Calls `build()` with the parameters at loading time</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">config</span><span class="p">[</span><span class="s2">"input_shape"</span><span class="p">],</span> <span class="n">config</span><span class="p">[</span><span class="s2">"layer_init"</span><span class="p">])</span>


<span class="n">custom_layer</span> <span class="o">=</span> <span class="n">LayerWithCustomBuild</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="mi">16</span><span class="p">)</span>
<span class="n">custom_layer</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">input_shape</span><span class="o">=</span><span class="p">(</span><span class="mi">8</span><span class="p">,),</span> <span class="n">layer_init</span><span class="o">=</span><span class="s2">"random_normal"</span><span class="p">)</span>

<span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
    <span class="p">[</span>
        <span class="n">custom_layer</span><span class="p">,</span>
        <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"sigmoid"</span><span class="p">),</span>
    <span class="p">]</span>
<span class="p">)</span>

<span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">16</span><span class="p">,</span> <span class="mi">8</span><span class="p">))</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>

<span class="n">model</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s2">"custom_build_model.keras"</span><span class="p">)</span>
<span class="n">restored_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">load_model</span><span class="p">(</span><span class="s2">"custom_build_model.keras"</span><span class="p">)</span>

<span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_equal</span><span class="p">(</span><span class="n">restored_model</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">layer_init</span><span class="p">,</span> <span class="s2">"random_normal"</span><span class="p">)</span>
<span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_equal</span><span class="p">(</span><span class="n">restored_model</span><span class="o">.</span><span class="n">built</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span>
</code></pre></div>
<h3 id="getcompileconfig-and-compilefromconfig"><code>get_compile_config()</code> and <code>compile_from_config()</code></h3>
<p>These methods work together to save the information with which the model was compiled (optimizers, losses, etc.) 
and to restore and re-compile the model from that information.</p> <p>Overriding these methods is useful when the restored model must be compiled with custom optimizers, custom losses, or other custom objects, because these need to be deserialized before <code>model.compile</code> is called in <code>compile_from_config()</code>.</p> <p>The example below registers a custom loss and a custom metric, then overrides both methods on a subclassed model.</p> <p><strong>Example:</strong></p> <div class="codehilite"><pre><span></span><code><span class="nd">@keras</span><span class="o">.</span><span class="n">saving</span><span class="o">.</span><span class="n">register_keras_serializable</span><span class="p">(</span><span class="n">package</span><span class="o">=</span><span class="s2">"my_custom_package"</span><span class="p">)</span> <span class="k">def</span> <span class="nf">small_square_sum_loss</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">y_pred</span> <span class="o">-</span> <span class="n">y_true</span><span class="p">)</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">loss</span> <span class="o">/</span> <span class="mf">10.0</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="k">return</span> <span class="n">loss</span> <span class="nd">@keras</span><span class="o">.</span><span class="n">saving</span><span class="o">.</span><span class="n">register_keras_serializable</span><span 
class="p">(</span><span class="n">package</span><span class="o">=</span><span class="s2">"my_custom_package"</span><span class="p">)</span> <span class="k">def</span> <span class="nf">mean_pred</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">y_pred</span><span class="p">)</span> <span class="nd">@keras</span><span class="o">.</span><span class="n">saving</span><span class="o">.</span><span class="n">register_keras_serializable</span><span class="p">(</span><span class="n">package</span><span class="o">=</span><span class="s2">"my_custom_package"</span><span class="p">)</span> <span class="k">class</span> <span class="nc">ModelWithCustomCompile</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense1</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span 
class="n">dense2</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"softmax"</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense1</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">def</span> <span class="nf">compile</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">,</span> <span class="n">metrics</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="n">loss_fn</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="n">metrics</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_optimizer</span> <span class="o">=</span> <span class="n">optimizer</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span> <span class="o">=</span> <span 
class="n">loss_fn</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_metrics</span> <span class="o">=</span> <span class="n">metrics</span> <span class="k">def</span> <span class="nf">get_compile_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="c1"># These parameters will be serialized at saving time.</span> <span class="k">return</span> <span class="p">{</span> <span class="s2">"model_optimizer"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">model_optimizer</span><span class="p">,</span> <span class="s2">"loss_fn"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span><span class="p">,</span> <span class="s2">"metric"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_metrics</span><span class="p">,</span> <span class="p">}</span> <span class="k">def</span> <span class="nf">compile_from_config</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">config</span><span class="p">):</span> <span class="c1"># Deserializes the compile parameters (important, since many are custom)</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">deserialize_keras_object</span><span class="p">(</span><span class="n">config</span><span class="p">[</span><span class="s2">"model_optimizer"</span><span class="p">])</span> <span class="n">loss_fn</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">deserialize_keras_object</span><span class="p">(</span><span class="n">config</span><span class="p">[</span><span class="s2">"loss_fn"</span><span class="p">])</span> <span 
class="n">metrics</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">deserialize_keras_object</span><span class="p">(</span><span class="n">config</span><span class="p">[</span><span class="s2">"metric"</span><span class="p">])</span> <span class="c1"># Calls compile with the deserialized parameters</span> <span class="bp">self</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">loss_fn</span><span class="o">=</span><span class="n">loss_fn</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="n">metrics</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">ModelWithCustomCompile</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="s2">"SGD"</span><span class="p">,</span> <span class="n">loss_fn</span><span class="o">=</span><span class="n">small_square_sum_loss</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"accuracy"</span><span class="p">,</span> <span class="n">mean_pred</span><span class="p">]</span> <span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">4</span><span class="p">,</span> <span class="mi">8</span><span class="p">))</span> <span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span 
class="n">random</span><span class="p">((</span><span class="mi">4</span><span class="p">,))</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s2">"custom_compile_model.keras"</span><span class="p">)</span> <span class="n">restored_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">load_model</span><span class="p">(</span><span class="s2">"custom_compile_model.keras"</span><span class="p">)</span> <span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_equal</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">model_optimizer</span><span class="p">,</span> <span class="n">restored_model</span><span class="o">.</span><span class="n">model_optimizer</span><span class="p">)</span> <span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_equal</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">loss_fn</span><span class="p">,</span> <span class="n">restored_model</span><span class="o">.</span><span class="n">loss_fn</span><span class="p">)</span> <span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_equal</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">loss_metrics</span><span class="p">,</span> <span class="n">restored_model</span><span class="o">.</span><span class="n">loss_metrics</span><span class="p">)</span> </code></pre></div> <div 
class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 79ms/step - accuracy: 0.0000e+00 - loss: 0.0627 - mean_metric_wrapper: 0.2500 </code></pre></div> </div> <hr /> <h2 id="conclusion">Conclusion</h2> <p>The methods covered in this guide support a wide variety of use cases, including saving and loading complex models with exotic assets and state elements. To recap:</p> <ul> <li><code>save_own_variables</code> and <code>load_own_variables</code> determine how your model's state is saved and loaded.</li> <li><code>save_assets</code> and <code>load_assets</code> can be added to store and load any additional information your model needs.</li> <li><code>get_build_config</code> and <code>build_from_config</code> save and restore the model's built state.</li> <li><code>get_compile_config</code> and <code>compile_from_config</code> save and restore the model's compiled state.</li> </ul> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#customizing-saving-and-serialization'>Customizing Saving and Serialization</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-3'> <a href='#apis'>APIs</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#state-saving-customization'>State saving customization</a> </div> <div class='k-outline-depth-3'> <a href='#saveownvariables-and-loadownvariables'><code>save_own_variables()</code> and <code>load_own_variables()</code></a> </div> <div class='k-outline-depth-3'> <a href='#saveassets-and-loadassets'><code>save_assets()</code> and <code>load_assets()</code></a> </div> <div class='k-outline-depth-2'> ◆ <a href='#build-and-compile-saving-customization'><code>build</code> and <code>compile</code> saving customization</a> </div> <div class='k-outline-depth-3'> <a 
href='#getbuildconfig-and-buildfromconfig'><code>get_build_config()</code> and <code>build_from_config()</code></a> </div> <div class='k-outline-depth-3'> <a href='#getcompileconfig-and-compilefromconfig'><code>get_compile_config()</code> and <code>compile_from_config()</code></a> </div> <div class='k-outline-depth-2'> ◆ <a href='#conclusion'>Conclusion</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>