<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<meta name="description" content="Keras documentation">
<meta name="author" content="Keras Team">
<link rel="shortcut icon" href="https://keras.io/img/favicon.ico">
<link rel="canonical" href="https://keras.io/guides/serialization_and_saving/" />
<!-- Social -->
<meta property="og:title" content="Keras documentation: Save, serialize, and export models">
<meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png">
<meta name="twitter:title" content="Keras documentation: Save, serialize, and export models">
<meta name="twitter:image" content="https://keras.io/img/k-keras-social.png">
<meta name="twitter:card" content="summary">
<title>Save, serialize, and export models</title>
<!-- Bootstrap core CSS -->
<link href="/css/bootstrap.min.css" rel="stylesheet">
<!-- Custom fonts for this template -->
<link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet">
<!-- Custom styles for this template -->
<link href="/css/docs.css" rel="stylesheet">
<link href="/css/monokai.css" rel="stylesheet">
</head>
<body>
<div class='k-page'>
<div class='k-main'>
<div class='k-main-inner' id='k-main-id'>
<div class='k-content'>
<h1 id="save-serialize-and-export-models">Save, serialize, and export models</h1>
<p><strong>Authors:</strong> Neel Kovelamudi, Francois Chollet<br> <strong>Date created:</strong> 2023/06/14<br> <strong>Last modified:</strong> 2023/06/30<br> <strong>Description:</strong> Complete guide to saving, serializing, and exporting models.</p>
<p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a 
href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/serialization_and_saving.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/guides/serialization_and_saving.py"><strong>GitHub source</strong></a></p>
<hr />
<h2 id="introduction">Introduction</h2>
<p>A Keras model consists of multiple components:</p>
<ul>
<li>The architecture, or configuration, which specifies what layers the model contains, and how they're connected.</li>
<li>A set of weight values (the "state of the model").</li>
<li>An optimizer (defined by compiling the model).</li>
<li>A set of losses and metrics (defined by compiling the model).</li>
</ul>
<p>The Keras API saves all of these pieces together in a unified format, marked by the <code>.keras</code> extension. This is a zip archive consisting of the following:</p>
<ul>
<li>A JSON-based configuration file (<code>config.json</code>): records the configuration of the model, its layers, and other trackables.</li>
<li>An H5-based state file, such as <code>model.weights.h5</code> (for the whole model), with directory keys for layers and their weights.</li>
<li>A metadata file in JSON, storing things such as the current Keras version.</li>
</ul>
<p>Let's take a look at how this works.</p>
<hr />
<h2 id="how-to-save-and-load-a-model">How to save and load a model</h2>
<p>If you only have 10 seconds to read this guide, here's what you need to know.</p>
<p><strong>Saving a Keras model:</strong></p>
<div class="codehilite"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="o">...</span>  <span class="c1"># Get model (Sequential, Functional Model, or Model subclass)</span>
<span class="n">model</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s1">'path/to/location.keras'</span><span class="p">)</span>  <span class="c1"># The file needs to end with the .keras extension</span>
</code></pre></div>
<p><strong>Loading the model back:</strong></p>
<div class="codehilite"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">load_model</span><span class="p">(</span><span class="s1">'path/to/location.keras'</span><span class="p">)</span>
</code></pre></div>
<p>Now, let's look at the details.</p>
<hr />
<h2 id="setup">Setup</h2>
<div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">keras</span>
<span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">ops</span>
</code></pre></div>
<hr />
<h2 id="saving">Saving</h2>
<p>This section is about saving an entire model to a single file. The file will include:</p>
<ul>
<li>The model's architecture/config</li>
<li>The model's weight values (which were learned during training)</li>
<li>The model's compilation information (if <code>compile()</code> was called)</li>
<li>The optimizer and its state, if any (this enables you to restart training where you left off)</li>
</ul>
<h4 id="apis">APIs</h4>
<p>You can save a model with <code>model.save()</code> or <code>keras.models.save_model()</code> (which is equivalent). 
You can load it back with <code>keras.models.load_model()</code>.</p>
<p>The only supported format in Keras 3 is the "Keras v3" format, which uses the <code>.keras</code> extension.</p>
<p><strong>Example:</strong></p>
<div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">get_model</span><span class="p">():</span>
    <span class="c1"># Create a simple model.</span>
    <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">32</span><span class="p">,))</span>
    <span class="n">outputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span>
    <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span>
    <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(),</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"mean_squared_error"</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">model</span>


<span class="n">model</span> <span class="o">=</span> <span class="n">get_model</span><span class="p">()</span>

<span class="c1"># Train the model.</span>
<span class="n">test_input</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">128</span><span class="p">,</span> <span class="mi">32</span><span class="p">))</span>
<span class="n">test_target</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">128</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">test_input</span><span class="p">,</span> <span class="n">test_target</span><span class="p">)</span>

<span class="c1"># Calling `save('my_model.keras')` creates a zip archive `my_model.keras`.</span>
<span class="n">model</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s2">"my_model.keras"</span><span class="p">)</span>

<span class="c1"># It can be used to reconstruct the model identically.</span>
<span class="n">reconstructed_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">load_model</span><span class="p">(</span><span class="s2">"my_model.keras"</span><span class="p">)</span>

<span class="c1"># Let's check:</span>
<span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span>
    <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">test_input</span><span class="p">),</span> <span class="n">reconstructed_model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">test_input</span><span class="p">)</span>
<span 
class="p">)</span>
</code></pre></div>
<div class="k-default-codeblock">
<div class="codehilite"><pre><span></span><code> 4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.4232
 4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 281us/step
 4/4 ━━━━━━━━━━━━━━━━━━━━ 0s 373us/step
</code></pre></div>
</div>
<h3 id="custom-objects">Custom objects</h3>
<p>This section covers the basic workflows for handling custom layers, functions, and models in Keras saving and reloading.</p>
<p>When saving a model that includes custom objects, such as a subclassed Layer, you <strong>must</strong> define a <code>get_config()</code> method on the object class. If the arguments passed to the constructor (<code>__init__()</code> method) of the custom object aren't plain Python base types (that is, they are anything other than ints, strings, and the like), then you <strong>must</strong> also explicitly deserialize these arguments in the <code>from_config()</code> class method.</p>
<p>Like this:</p>
<div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">CustomLayer</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sublayer</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">sublayer</span> <span class="o">=</span> <span class="n">sublayer</span>

    <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">sublayer</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">get_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="n">base_config</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">get_config</span><span class="p">()</span>
        <span class="n">config</span> <span class="o">=</span> <span class="p">{</span>
            <span class="s2">"sublayer"</span><span class="p">:</span> <span class="n">keras</span><span class="o">.</span><span class="n">saving</span><span class="o">.</span><span class="n">serialize_keras_object</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">sublayer</span><span class="p">),</span>
        <span class="p">}</span>
        <span class="k">return</span> <span class="p">{</span><span class="o">**</span><span class="n">base_config</span><span class="p">,</span> <span class="o">**</span><span class="n">config</span><span class="p">}</span>

    <span class="nd">@classmethod</span>
    <span class="k">def</span> <span class="nf">from_config</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">config</span><span class="p">):</span>
        <span class="n">sublayer_config</span> <span class="o">=</span> <span class="n">config</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">"sublayer"</span><span class="p">)</span>
        <span class="n">sublayer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">saving</span><span class="o">.</span><span class="n">deserialize_keras_object</span><span class="p">(</span><span class="n">sublayer_config</span><span class="p">)</span>
        <span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="n">sublayer</span><span class="p">,</span> <span class="o">**</span><span class="n">config</span><span class="p">)</span>
</code></pre></div>
<p>Please see the <a href="#config_methods">Defining the config methods</a> section for more details and examples.</p>
<p>The saved <code>.keras</code> file is lightweight and does not store the Python code for custom objects. Therefore, to reload the model, <code>load_model</code> requires access to the definition of any custom objects used, through one of the following methods:</p>
<ol>
<li>Registering custom objects <strong>(preferred)</strong>,</li>
<li>Passing custom objects directly when loading, or</li>
<li>Using a custom object scope</li>
</ol>
<p>Below are examples of each workflow:</p>
<h4 id="preferred">Registering custom objects (<strong>preferred</strong>)</h4>
<p>This is the preferred method, as custom object registration greatly simplifies saving and loading code. 
Adding the <code>@keras.saving.register_keras_serializable</code> decorator to the class definition of a custom object registers the object globally in a master list, allowing Keras to recognize the object when loading the model.</p>
<p>Let's create a custom model involving both a custom layer and a custom activation function to demonstrate this.</p>
<p><strong>Example:</strong></p>
<div class="codehilite"><pre><span></span><code><span class="c1"># Clear all previously registered custom objects</span>
<span class="n">keras</span><span class="o">.</span><span class="n">saving</span><span class="o">.</span><span class="n">get_custom_objects</span><span class="p">()</span><span class="o">.</span><span class="n">clear</span><span class="p">()</span>


<span class="c1"># Upon registration, you can optionally specify a package or a name.</span>
<span class="c1"># If left blank, the package defaults to `Custom` and the name defaults to</span>
<span class="c1"># the class name.</span>
<span class="nd">@keras</span><span class="o">.</span><span class="n">saving</span><span class="o">.</span><span class="n">register_keras_serializable</span><span class="p">(</span><span class="n">package</span><span class="o">=</span><span class="s2">"MyLayers"</span><span class="p">)</span>
<span class="k">class</span> <span class="nc">CustomLayer</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">factor</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">factor</span> <span class="o">=</span> <span class="n">factor</span>

    <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">factor</span>

    <span class="k">def</span> <span class="nf">get_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="k">return</span> <span class="p">{</span><span class="s2">"factor"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">factor</span><span class="p">}</span>


<span class="nd">@keras</span><span class="o">.</span><span class="n">saving</span><span class="o">.</span><span class="n">register_keras_serializable</span><span class="p">(</span><span class="n">package</span><span class="o">=</span><span class="s2">"my_package"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"custom_fn"</span><span class="p">)</span>
<span class="k">def</span> <span class="nf">custom_fn</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="k">return</span> <span class="n">x</span><span class="o">**</span><span class="mi">2</span>


<span class="c1"># Create the model.</span>
<span class="k">def</span> <span class="nf">get_model</span><span class="p">():</span>
    <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,))</span>
    <span class="n">mid</span> <span class="o">=</span> <span class="n">CustomLayer</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span>
    <span class="n">outputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">custom_fn</span><span class="p">)(</span><span class="n">mid</span><span class="p">)</span>
    <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span>
    <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="s2">"rmsprop"</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"mean_squared_error"</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">model</span>


<span class="c1"># Train the model.</span>
<span class="k">def</span> <span class="nf">train_model</span><span class="p">(</span><span class="n">model</span><span class="p">):</span>
    <span class="nb">input</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>
    <span class="n">target</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">4</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
    <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="nb">input</span><span class="p">,</span> <span class="n">target</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">model</span>


<span class="n">test_input</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span>
<span class="n">test_target</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">4</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>

<span class="n">model</span> <span class="o">=</span> <span class="n">get_model</span><span class="p">()</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">train_model</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>
<span class="n">model</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s2">"custom_model.keras"</span><span class="p">)</span>

<span class="c1"># Now, we can simply load without worrying about our custom objects.</span>
<span class="n">reconstructed_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">load_model</span><span class="p">(</span><span class="s2">"custom_model.keras"</span><span class="p">)</span>

<span class="c1"># Let's check:</span>
<span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span>
    <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span 
class="n">test_input</span><span class="p">),</span> <span class="n">reconstructed_model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">test_input</span><span class="p">)</span>
<span class="p">)</span>
</code></pre></div>
<div class="k-default-codeblock">
<div class="codehilite"><pre><span></span><code> 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 46ms/step - loss: 0.2571
 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step
 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step
</code></pre></div>
</div>
<h4 id="passing-custom-objects-to-loadmodel">Passing custom objects to <code>load_model()</code></h4>
<div class="codehilite"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="n">get_model</span><span class="p">()</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">train_model</span><span class="p">(</span><span class="n">model</span><span class="p">)</span>

<span class="c1"># Calling `save('custom_model.keras')` creates a zip archive `custom_model.keras`.</span>
<span class="n">model</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s2">"custom_model.keras"</span><span class="p">)</span>

<span class="c1"># Upon loading, pass a dict containing the custom objects used in the</span>
<span class="c1"># `custom_objects` argument of `keras.models.load_model()`.</span>
<span class="n">reconstructed_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">load_model</span><span class="p">(</span>
    <span class="s2">"custom_model.keras"</span><span class="p">,</span>
    <span class="n">custom_objects</span><span class="o">=</span><span class="p">{</span><span class="s2">"CustomLayer"</span><span class="p">:</span> <span class="n">CustomLayer</span><span class="p">,</span> <span class="s2">"custom_fn"</span><span class="p">:</span> <span class="n">custom_fn</span><span class="p">},</span>
<span class="p">)</span>

<span class="c1"># Let's check:</span>
<span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span>
    <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">test_input</span><span class="p">),</span> <span class="n">reconstructed_model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">test_input</span><span class="p">)</span>
<span class="p">)</span>
</code></pre></div>
<div class="k-default-codeblock">
<div class="codehilite"><pre><span></span><code> 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 37ms/step - loss: 0.0535
 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step
 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step
</code></pre></div>
</div>
<h4 id="using-a-custom-object-scope">Using a custom object scope</h4>
<p>Any code within the custom object scope will be able to recognize the custom objects passed to the scope argument. 
Therefore, loading the model within the scope will allow the loading of our custom objects.</p> <p><strong>Example:</strong></p> <div class="codehilite"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="n">get_model</span><span class="p">()</span> <span class="n">model</span> <span class="o">=</span> <span class="n">train_model</span><span class="p">(</span><span class="n">model</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s2">"custom_model.keras"</span><span class="p">)</span> <span class="c1"># Pass the custom objects dictionary to a custom object scope and place</span> <span class="c1"># the `keras.models.load_model()` call within the scope.</span> <span class="n">custom_objects</span> <span class="o">=</span> <span class="p">{</span><span class="s2">"CustomLayer"</span><span class="p">:</span> <span class="n">CustomLayer</span><span class="p">,</span> <span class="s2">"custom_fn"</span><span class="p">:</span> <span class="n">custom_fn</span><span class="p">}</span> <span class="k">with</span> <span class="n">keras</span><span class="o">.</span><span class="n">saving</span><span class="o">.</span><span class="n">custom_object_scope</span><span class="p">(</span><span class="n">custom_objects</span><span class="p">):</span> <span class="n">reconstructed_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">load_model</span><span class="p">(</span><span class="s2">"custom_model.keras"</span><span class="p">)</span> <span class="c1"># Let's check:</span> <span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span 
class="p">(</span><span class="n">test_input</span><span class="p">),</span> <span class="n">reconstructed_model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">test_input</span><span class="p">)</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 40ms/step - loss: 0.0868 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step </code></pre></div> </div> <h3 id="model-serialization">Model serialization</h3> <p>This section is about saving only the model's configuration, without its state. The model's configuration (or architecture) specifies which layers the model contains and how these layers are connected. If you have the configuration of a model, the model can be recreated with a freshly initialized state (no weights or compilation information).</p> <h4 id="apis">APIs</h4> <p>The following serialization APIs are available:</p> <ul> <li><code>keras.models.clone_model(model)</code>: make a (randomly initialized) copy of a model.</li> <li><code>get_config()</code> and <code>cls.from_config()</code>: retrieve the configuration of a layer or model, and recreate a model instance from its config, respectively.</li> <li><code>model.to_json()</code> and <code>keras.models.model_from_json()</code>: similar, but as JSON strings.</li> <li><code>keras.saving.serialize_keras_object()</code>: retrieve the configuration of any arbitrary Keras object.</li> <li><code>keras.saving.deserialize_keras_object()</code>: recreate an object instance from its configuration.</li> </ul> <h4 id="inmemory-model-cloning">In-memory model cloning</h4> <p>You can do in-memory cloning of a model via <code>keras.models.clone_model()</code>. 
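</p> <p>A minimal sketch of what cloning does and does not carry over, assuming a small hypothetical <code>Sequential</code> model (the names <code>original</code> and <code>clone</code> are illustrative only): the clone has the same architecture but its own freshly initialized weights, which can be copied over explicitly with <code>set_weights()</code> if needed.</p>

```python
import numpy as np
import keras

# Hypothetical toy model, used only for illustration.
original = keras.Sequential([keras.Input((32,)), keras.layers.Dense(1)])

# The clone shares the architecture but not the weight values.
clone = keras.models.clone_model(original)
same_param_count = clone.count_params() == original.count_params()

# Copy the weight values explicitly if the clone should match the original.
clone.set_weights(original.get_weights())
weights_match = all(
    np.allclose(a, b)
    for a, b in zip(original.get_weights(), clone.get_weights())
)
print(same_param_count, weights_match)
```

<p>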
Cloning is equivalent to getting the config and then recreating the model from that config, so it does not preserve compilation information or layer weight values.</p> <p><strong>Example:</strong></p> <div class="codehilite"><pre><span></span><code><span class="n">new_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">clone_model</span><span class="p">(</span><span class="n">model</span><span class="p">)</span> </code></pre></div> <h4 id="getconfig-and-fromconfig"><code>get_config()</code> and <code>from_config()</code></h4> <p>Calling <code>model.get_config()</code> or <code>layer.get_config()</code> returns a Python dict containing the configuration of the model or layer, respectively. Define <code>get_config()</code> so that it returns the arguments needed by the <code>__init__()</code> method of the model or layer. At loading time, the <code>from_config(config)</code> method then calls <code>__init__()</code> with these arguments to reconstruct the model or layer.</p> <p><strong>Layer example:</strong></p> <div class="codehilite"><pre><span></span><code><span class="n">layer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)</span> <span class="n">layer_config</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">get_config</span><span class="p">()</span> <span class="nb">print</span><span class="p">(</span><span class="n">layer_config</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>{'name': 'dense_4', 'trainable': True, 'dtype': 
'float32', 'units': 3, 'activation': 'relu', 'use_bias': True, 'kernel_initializer': {'module': 'keras.src.initializers.random_initializers', 'class_name': 'GlorotUniform', 'config': {'seed': None}, 'registered_name': 'GlorotUniform'}, 'bias_initializer': {'module': 'keras.src.initializers.constant_initializers', 'class_name': 'Zeros', 'config': {}, 'registered_name': 'Zeros'}, 'kernel_regularizer': None, 'bias_regularizer': None, 'kernel_constraint': None, 'bias_constraint': None} </code></pre></div> </div> <p>Now let's reconstruct the layer using the <code>from_config()</code> method:</p> <div class="codehilite"><pre><span></span><code><span class="n">new_layer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="o">.</span><span class="n">from_config</span><span class="p">(</span><span class="n">layer_config</span><span class="p">)</span> </code></pre></div> <p><strong>Sequential model example:</strong></p> <div class="codehilite"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span><span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">((</span><span class="mi">32</span><span class="p">,)),</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">)])</span> <span class="n">config</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">get_config</span><span class="p">()</span> <span class="n">new_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="o">.</span><span 
class="n">from_config</span><span class="p">(</span><span class="n">config</span><span class="p">)</span> </code></pre></div> <p><strong>Functional model example:</strong></p> <div class="codehilite"><pre><span></span><code><span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">((</span><span class="mi">32</span><span class="p">,))</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> <span class="n">config</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">get_config</span><span class="p">()</span> <span class="n">new_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="o">.</span><span class="n">from_config</span><span class="p">(</span><span class="n">config</span><span class="p">)</span> </code></pre></div> <h4 id="tojson-and-kerasmodelsmodelfromjson"><code>to_json()</code> and <code>keras.models.model_from_json()</code></h4> <p>This is similar to <code>get_config</code> / <code>from_config</code>, except it turns the model into a JSON string, which can then be loaded without the original model class. 
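</p> <p>As a rough illustration of what the JSON string holds (a sketch assuming a small hypothetical <code>Sequential</code> model; the exact set of keys may differ between Keras versions), it can be parsed with the standard <code>json</code> module and inspected before the model is rebuilt:</p>

```python
import json
import keras

# Hypothetical toy model, used only for illustration.
model = keras.Sequential([keras.Input((32,)), keras.layers.Dense(1)])
json_config = model.to_json()

# The result is an ordinary JSON string, so it can be parsed and inspected.
parsed = json.loads(json_config)
print(parsed["class_name"])

# Rebuild a freshly initialized model from the string alone.
rebuilt = keras.models.model_from_json(json_config)
print(rebuilt.count_params() == model.count_params())
```

<p>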
<code>to_json()</code> is also specific to models; it isn't meant for layers.</p> <p><strong>Example:</strong></p> <div class="codehilite"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span><span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">((</span><span class="mi">32</span><span class="p">,)),</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">)])</span> <span class="n">json_config</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">to_json</span><span class="p">()</span> <span class="n">new_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">model_from_json</span><span class="p">(</span><span class="n">json_config</span><span class="p">)</span> </code></pre></div> <h4 id="arbitrary-object-serialization-and-deserialization">Arbitrary object serialization and deserialization</h4> <p>The <code>keras.saving.serialize_keras_object()</code> and <code>keras.saving.deserialize_keras_object()</code> APIs are general-purpose APIs that can be used to serialize or deserialize any Keras object and any custom object. 
They are the foundation of saving model architecture and are behind all <code>serialize()</code>/<code>deserialize()</code> calls in Keras.</p> <p><strong>Example:</strong></p> <div class="codehilite"><pre><span></span><code><span class="n">my_reg</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">regularizers</span><span class="o">.</span><span class="n">L1</span><span class="p">(</span><span class="mf">0.005</span><span class="p">)</span> <span class="n">config</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">saving</span><span class="o">.</span><span class="n">serialize_keras_object</span><span class="p">(</span><span class="n">my_reg</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">config</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>{'module': 'keras.src.regularizers.regularizers', 'class_name': 'L1', 'config': {'l1': 0.004999999888241291}, 'registered_name': 'L1'} </code></pre></div> </div> <p>Note that the serialization format contains all the information needed for proper reconstruction:</p> <ul> <li><code>module</code> containing the name of the Keras module or other identifying module the object comes from.</li> <li><code>class_name</code> containing the name of the object's class.</li> <li><code>config</code> with all the information needed to reconstruct the object.</li> <li><code>registered_name</code> for custom objects. 
See <a href="#custom_object_serialization">here</a>.</li> </ul> <p>Now we can reconstruct the regularizer.</p> <div class="codehilite"><pre><span></span><code><span class="n">new_reg</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">saving</span><span class="o">.</span><span class="n">deserialize_keras_object</span><span class="p">(</span><span class="n">config</span><span class="p">)</span> </code></pre></div> <h3 id="model-weights-saving">Model weights saving</h3> <p>You can choose to save and load only a model's weights. This can be useful if:</p> <ul> <li>You only need the model for inference: in this case you won't need to restart training, so you don't need the compilation information or optimizer state.</li> <li>You are doing transfer learning: in this case you will be training a new model reusing the state of a prior model, so you don't need the compilation information of the prior model.</li> </ul> <h4 id="apis-for-inmemory-weight-transfer">APIs for in-memory weight transfer</h4> <p>Weights can be copied between different objects by using <code>get_weights()</code> and <code>set_weights()</code>:</p> <ul> <li><code>keras.layers.Layer.get_weights()</code>: Returns a list of NumPy arrays of weight values.</li> <li><code>keras.layers.Layer.set_weights(weights)</code>: Sets the layer's weights to the values provided (as NumPy arrays).</li> </ul> <p>Examples:</p> <p><strong><em>Transferring weights from one layer to another, in memory</em></strong></p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">create_layer</span><span class="p">():</span> <span class="n">layer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span 
class="s2">"relu"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"dense_2"</span><span class="p">)</span> <span class="n">layer</span><span class="o">.</span><span class="n">build</span><span class="p">((</span><span class="kc">None</span><span class="p">,</span> <span class="mi">784</span><span class="p">))</span> <span class="k">return</span> <span class="n">layer</span> <span class="n">layer_1</span> <span class="o">=</span> <span class="n">create_layer</span><span class="p">()</span> <span class="n">layer_2</span> <span class="o">=</span> <span class="n">create_layer</span><span class="p">()</span> <span class="c1"># Copy weights from layer 1 to layer 2</span> <span class="n">layer_2</span><span class="o">.</span><span class="n">set_weights</span><span class="p">(</span><span class="n">layer_1</span><span class="o">.</span><span class="n">get_weights</span><span class="p">())</span> </code></pre></div> <p><strong><em>Transferring weights from one model to another model with a compatible architecture, in memory</em></strong></p> <div class="codehilite"><pre><span></span><code><span class="c1"># Create a simple functional model</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">784</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"digits"</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">,</span> <span 
class="n">name</span><span class="o">=</span><span class="s2">"dense_1"</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"dense_2"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"predictions"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">functional_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">outputs</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"3_layer_mlp"</span><span class="p">)</span> <span class="c1"># Define a subclassed model with the same architecture</span> <span class="k">class</span> <span class="nc">SubclassedModel</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span 
class="bp">self</span><span class="p">,</span> <span class="n">output_dim</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_dim</span> <span class="o">=</span> <span class="n">output_dim</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_1</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"dense_1"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_2</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"dense_2"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_3</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">output_dim</span><span class="p">,</span> <span 
class="n">name</span><span class="o">=</span><span class="s2">"predictions"</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_1</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_3</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> <span class="k">def</span> <span class="nf">get_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="k">return</span> <span class="p">{</span><span class="s2">"output_dim"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_dim</span><span class="p">,</span> <span class="s2">"name"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">}</span> <span class="n">subclassed_model</span> <span class="o">=</span> <span class="n">SubclassedModel</span><span class="p">(</span><span class="mi">10</span><span class="p">)</span> <span class="c1"># Call the subclassed model once to create the weights.</span> <span class="n">subclassed_model</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">784</span><span 
class="p">)))</span> <span class="c1"># Copy weights from functional_model to subclassed_model.</span> <span class="n">subclassed_model</span><span class="o">.</span><span class="n">set_weights</span><span class="p">(</span><span class="n">functional_model</span><span class="o">.</span><span class="n">get_weights</span><span class="p">())</span> <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">functional_model</span><span class="o">.</span><span class="n">weights</span><span class="p">)</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">subclassed_model</span><span class="o">.</span><span class="n">weights</span><span class="p">)</span> <span class="k">for</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">functional_model</span><span class="o">.</span><span class="n">weights</span><span class="p">,</span> <span class="n">subclassed_model</span><span class="o">.</span><span class="n">weights</span><span class="p">):</span> <span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">b</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span> </code></pre></div> <p><strong><em>The case of stateless layers</em></strong></p> <p>Because stateless layers do not change the order or number of weights, models can have compatible architectures even if there are extra/missing stateless layers.</p> <div class="codehilite"><pre><span></span><code><span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span 
class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">784</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"digits"</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"dense_1"</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"dense_2"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"predictions"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">functional_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> 
<span class="n">outputs</span><span class="o">=</span><span class="n">outputs</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"3_layer_mlp"</span><span class="p">)</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">784</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"digits"</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"dense_1"</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"dense_2"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Add a dropout layer, which does not contain any weights.</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span 
class="mf">0.5</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"predictions"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">functional_model_with_dropout</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span> <span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">outputs</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"3_layer_mlp"</span> <span class="p">)</span> <span class="n">functional_model_with_dropout</span><span class="o">.</span><span class="n">set_weights</span><span class="p">(</span><span class="n">functional_model</span><span class="o">.</span><span class="n">get_weights</span><span class="p">())</span> </code></pre></div> <h4 id="apis-for-saving-weights-to-disk-amp-loading-them-back">APIs for saving weights to disk & loading them back</h4> <p>Weights can be saved to disk by calling <code>model.save_weights(filepath)</code>. 
The filename should end in <code>.weights.h5</code>.</p> <p><strong>Example:</strong></p> <div class="codehilite"><pre><span></span><code><span class="c1"># Runnable example</span> <span class="n">sequential_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">784</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"digits"</span><span class="p">),</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"dense_1"</span><span class="p">),</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"dense_2"</span><span class="p">),</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"predictions"</span><span class="p">),</span> <span class="p">]</span> <span class="p">)</span> <span 
class="n">sequential_model</span><span class="o">.</span><span class="n">save_weights</span><span class="p">(</span><span class="s2">"my_model.weights.h5"</span><span class="p">)</span> <span class="n">sequential_model</span><span class="o">.</span><span class="n">load_weights</span><span class="p">(</span><span class="s2">"my_model.weights.h5"</span><span class="p">)</span> </code></pre></div> <p>Note that changing <code>layer.trainable</code> may result in a different <code>layer.weights</code> ordering when the model contains nested layers. The example below checks for this; in the run shown here, the ordering did not change.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">NestedDenseLayer</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">units</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_1</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"dense_1"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_2</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span 
class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"dense_2"</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dense_1</span><span class="p">(</span><span class="n">inputs</span><span class="p">))</span> <span class="n">nested_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span><span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">((</span><span class="mi">784</span><span class="p">,)),</span> <span class="n">NestedDenseLayer</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="s2">"nested"</span><span class="p">)])</span> <span class="n">variable_names</span> <span class="o">=</span> <span class="p">[</span><span class="n">v</span><span class="o">.</span><span class="n">name</span> <span class="k">for</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">nested_model</span><span class="o">.</span><span class="n">weights</span><span class="p">]</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"variables: </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">variable_names</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"</span><span class="se">\n</span><span 
class="s2">Changing trainable status of one of the nested layers..."</span><span class="p">)</span> <span class="n">nested_model</span><span class="o">.</span><span class="n">get_layer</span><span class="p">(</span><span class="s2">"nested"</span><span class="p">)</span><span class="o">.</span><span class="n">dense_1</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">False</span> <span class="n">variable_names_2</span> <span class="o">=</span> <span class="p">[</span><span class="n">v</span><span class="o">.</span><span class="n">name</span> <span class="k">for</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">nested_model</span><span class="o">.</span><span class="n">weights</span><span class="p">]</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"</span><span class="se">\n</span><span class="s2">variables: </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">variable_names_2</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"variable ordering changed:"</span><span class="p">,</span> <span class="n">variable_names</span> <span class="o">!=</span> <span class="n">variable_names_2</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>variables: ['kernel', 'bias', 'kernel', 'bias'] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Changing trainable status of one of the nested layers... 
</code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>variables: ['kernel', 'bias', 'kernel', 'bias'] variable ordering changed: False </code></pre></div> </div> <h5><strong>Transfer learning example</strong></h5> <p>When loading pretrained weights from a weights file, it is recommended to load the weights into the original checkpointed model, and then extract the desired weights/layers into a new model.</p> <p><strong>Example:</strong></p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">create_functional_model</span><span class="p">():</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">784</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"digits"</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"dense_1"</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">,</span> <span class="n">name</span><span 
class="o">=</span><span class="s2">"dense_2"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"predictions"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">outputs</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"3_layer_mlp"</span><span class="p">)</span> <span class="n">functional_model</span> <span class="o">=</span> <span class="n">create_functional_model</span><span class="p">()</span> <span class="n">functional_model</span><span class="o">.</span><span class="n">save_weights</span><span class="p">(</span><span class="s2">"pretrained.weights.h5"</span><span class="p">)</span> <span class="c1"># In a separate program:</span> <span class="n">pretrained_model</span> <span class="o">=</span> <span class="n">create_functional_model</span><span class="p">()</span> <span class="n">pretrained_model</span><span class="o">.</span><span class="n">load_weights</span><span class="p">(</span><span class="s2">"pretrained.weights.h5"</span><span class="p">)</span> <span class="c1"># Create a new model by extracting layers from the original model:</span> <span class="n">extracted_layers</span> <span class="o">=</span> <span class="n">pretrained_model</span><span class="o">.</span><span class="n">layers</span><span 
class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">extracted_layers</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"dense_3"</span><span class="p">))</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span><span class="n">extracted_layers</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> </code></pre></div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "sequential_4"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ dense_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">50,240</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dense_2 (<span style="color: #0087ff; text-decoration-color: 
#0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">4,160</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dense_3 (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">5</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">325</span> │ └─────────────────────────────────┴───────────────────────────┴────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">54,725</span> (213.77 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">54,725</span> (213.77 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <h3 id="appendix-handling-custom-objects">Appendix: Handling custom objects</h3> <p><a name="config_methods"></a></p> <h4 id="defining-the-config-methods">Defining the config methods</h4> <p>Specifications:</p> <ul> <li><code>get_config()</code> should return a JSON-serializable dictionary in order to be compatible with the Keras architecture- and model-saving APIs.</li> <li><code>from_config(config)</code> (a 
<code>classmethod</code>) should return a new layer or model object that is created from the config. The default implementation returns <code>cls(**config)</code>.</li> </ul> <p><strong>NOTE</strong>: If all your constructor arguments are already serializable, e.g., strings and ints, or non-custom Keras objects, overriding <code>from_config</code> is not necessary. However, for more complex objects such as layers or models passed to <code>__init__</code>, deserialization must be handled explicitly, either in <code>__init__</code> itself or by overriding the <code>from_config()</code> method.</p> <p><strong>Example:</strong></p> <div class="codehilite"><pre><span></span><code><span class="nd">@keras</span><span class="o">.</span><span class="n">saving</span><span class="o">.</span><span class="n">register_keras_serializable</span><span class="p">(</span><span class="n">package</span><span class="o">=</span><span class="s2">"MyLayers"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"KernelMult"</span><span class="p">)</span> <span class="k">class</span> <span class="nc">MyDense</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">units</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span> <span class="n">kernel_regularizer</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">kernel_initializer</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">nested_model</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span> <span class="p">):</span> 
<span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_units</span> <span class="o">=</span> <span class="n">units</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel_regularizer</span> <span class="o">=</span> <span class="n">kernel_regularizer</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel_initializer</span> <span class="o">=</span> <span class="n">kernel_initializer</span> <span class="bp">self</span><span class="o">.</span><span class="n">nested_model</span> <span class="o">=</span> <span class="n">nested_model</span> <span class="k">def</span> <span class="nf">get_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="n">config</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">get_config</span><span class="p">()</span> <span class="c1"># Update the config with the custom layer's parameters</span> <span class="n">config</span><span class="o">.</span><span class="n">update</span><span class="p">(</span> <span class="p">{</span> <span class="s2">"units"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_units</span><span class="p">,</span> <span class="s2">"kernel_regularizer"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel_regularizer</span><span class="p">,</span> <span class="s2">"kernel_initializer"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel_initializer</span><span class="p">,</span> <span class="s2">"nested_model"</span><span class="p">:</span> <span 
class="bp">self</span><span class="o">.</span><span class="n">nested_model</span><span class="p">,</span> <span class="p">}</span> <span class="p">)</span> <span class="k">return</span> <span class="n">config</span> <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="n">input_units</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">name</span><span class="o">=</span><span class="s2">"kernel"</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">input_units</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_units</span><span class="p">),</span> <span class="n">regularizer</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">kernel_regularizer</span><span class="p">,</span> <span class="n">initializer</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">kernel_initializer</span><span class="p">,</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="k">return</span> <span class="n">ops</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span 
class="n">kernel</span><span class="p">)</span> <span class="n">layer</span> <span class="o">=</span> <span class="n">MyDense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">kernel_regularizer</span><span class="o">=</span><span class="s2">"l1"</span><span class="p">,</span> <span class="n">kernel_initializer</span><span class="o">=</span><span class="s2">"ones"</span><span class="p">)</span> <span class="n">layer3</span> <span class="o">=</span> <span class="n">MyDense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">nested_model</span><span class="o">=</span><span class="n">layer</span><span class="p">)</span> <span class="n">config</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">serialize</span><span class="p">(</span><span class="n">layer3</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">config</span><span class="p">)</span> <span class="n">new_layer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">deserialize</span><span class="p">(</span><span class="n">config</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">new_layer</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>{'module': None, 'class_name': 'MyDense', 'config': {'name': 'my_dense_1', 'trainable': True, 'dtype': 'float32', 'units': 64, 'kernel_regularizer': None, 'kernel_initializer': None, 'nested_model': {'module': None, 'class_name': 'MyDense', 'config': {'name': 'my_dense', 'trainable': True, 'dtype': 'float32', 
'units': 16, 'kernel_regularizer': 'l1', 'kernel_initializer': 'ones', 'nested_model': None}, 'registered_name': 'MyLayers>KernelMult'}}, 'registered_name': 'MyLayers>KernelMult'} &lt;MyDense name=my_dense_1, built=False&gt; </code></pre></div> </div> <p>Note that overriding <code>from_config</code> is unnecessary above for <code>MyDense</code> because <code>hidden_units</code>, <code>kernel_initializer</code>, and <code>kernel_regularizer</code> are ints, strings, and a built-in Keras object, respectively. This means that the default <code>from_config</code> implementation of <code>cls(**config)</code> will work as intended.</p> <p>For more complex objects, such as layers and models passed to <code>__init__</code>, you must deserialize these objects explicitly. Let's take a look at an example of a model where a <code>from_config</code> override is necessary.</p> <p><strong>Example:</strong> <a name="registration_example"></a></p> <div class="codehilite"><pre><span></span><code><span class="nd">@keras</span><span class="o">.</span><span class="n">saving</span><span class="o">.</span><span class="n">register_keras_serializable</span><span class="p">(</span><span class="n">package</span><span class="o">=</span><span class="s2">"ComplexModels"</span><span class="p">)</span> <span class="k">class</span> <span class="nc">CustomModel</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">first_layer</span><span class="p">,</span> <span class="n">second_layer</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span 
class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span> <span class="o">=</span> <span class="n">first_layer</span> <span class="k">if</span> <span class="n">second_layer</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">second_layer</span> <span class="o">=</span> <span class="n">second_layer</span> <span class="k">else</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">second_layer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">8</span><span class="p">)</span> <span class="k">def</span> <span class="nf">get_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="n">config</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">get_config</span><span class="p">()</span> <span class="n">config</span><span class="o">.</span><span class="n">update</span><span class="p">(</span> <span class="p">{</span> <span class="s2">"first_layer"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">,</span> <span class="s2">"second_layer"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">second_layer</span><span class="p">,</span> <span class="p">}</span> <span class="p">)</span> <span class="k">return</span> <span class="n">config</span> <span class="nd">@classmethod</span> <span class="k">def</span> <span 
class="nf">from_config</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">config</span><span class="p">):</span> <span class="c1"># Note that you can also use `keras.saving.deserialize_keras_object` here</span> <span class="n">config</span><span class="p">[</span><span class="s2">"first_layer"</span><span class="p">]</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">deserialize</span><span class="p">(</span><span class="n">config</span><span class="p">[</span><span class="s2">"first_layer"</span><span class="p">])</span> <span class="n">config</span><span class="p">[</span><span class="s2">"second_layer"</span><span class="p">]</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">deserialize</span><span class="p">(</span><span class="n">config</span><span class="p">[</span><span class="s2">"second_layer"</span><span class="p">])</span> <span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="o">**</span><span class="n">config</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">first_layer</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">second_layer</span><span class="p">(</span><span class="n">inputs</span><span class="p">))</span> <span class="c1"># Let's make our first layer the custom layer from the previous example (MyDense)</span> <span class="n">inputs</span> <span class="o">=</span> 
<span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">((</span><span class="mi">32</span><span class="p">,))</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">CustomModel</span><span class="p">(</span><span class="n">first_layer</span><span class="o">=</span><span class="n">layer</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> <span class="n">config</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">get_config</span><span class="p">()</span> <span class="n">new_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="o">.</span><span class="n">from_config</span><span class="p">(</span><span class="n">config</span><span class="p">)</span> </code></pre></div> <p><a name="custom_object_serialization"></a></p> <h4 id="how-custom-objects-are-serialized">How custom objects are serialized</h4> <p>The serialization format has a special key for custom objects registered via <code>@keras.saving.register_keras_serializable</code>. 
This <code>registered_name</code> key allows for easy retrieval at loading/deserialization time while also allowing users to add custom naming.</p> <p>Let's take a look at the config from serializing the custom layer <code>MyDense</code> we defined above.</p> <p><strong>Example</strong>:</p> <div class="codehilite"><pre><span></span><code><span class="n">layer</span> <span class="o">=</span> <span class="n">MyDense</span><span class="p">(</span> <span class="n">units</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">kernel_regularizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">regularizers</span><span class="o">.</span><span class="n">L1L2</span><span class="p">(</span><span class="n">l1</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">l2</span><span class="o">=</span><span class="mf">1e-4</span><span class="p">),</span> <span class="n">kernel_initializer</span><span class="o">=</span><span class="s2">"ones"</span><span class="p">,</span> <span class="p">)</span> <span class="n">config</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">serialize</span><span class="p">(</span><span class="n">layer</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">config</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>{'module': None, 'class_name': 'MyDense', 'config': {'name': 'my_dense_2', 'trainable': True, 'dtype': 'float32', 'units': 16, 'kernel_regularizer': {'module': 'keras.src.regularizers.regularizers', 'class_name': 'L1L2', 'config': {'l1': 1e-05, 'l2': 0.0001}, 'registered_name': 'L1L2'}, 'kernel_initializer': 'ones', 'nested_model': None}, 'registered_name': 'MyLayers>KernelMult'} 
</code></pre></div> </div> <p>As shown, the <code>registered_name</code> key contains the lookup information for the Keras master list, including the package <code>MyLayers</code> and the custom name <code>KernelMult</code> that we gave in the <code>@keras.saving.register_keras_serializable</code> decorator. The class definition and registration of <code>MyDense</code> appear under <a href="#defining-the-config-methods">Defining the config methods</a>.</p> <p>Note that the <code>class_name</code> key contains the original name of the class, allowing for proper re-initialization in <code>from_config</code>.</p> <p>Additionally, note that the <code>module</code> key is <code>None</code> since this is a custom object.</p> </div> </div> </div> </div> </body> </html>