CINXE.COM

Making new layers and models via subclassing

<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/guides/making_new_layers_and_models_via_subclassing/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Making new layers and models via subclassing"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Making new layers and models via subclassing"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Making new layers and models via subclassing</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link active" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-sublink" href="/guides/functional_api/">The Functional API</a> <a class="nav-sublink" href="/guides/sequential_model/">The Sequential model</a> <a class="nav-sublink active" href="/guides/making_new_layers_and_models_via_subclassing/">Making new layers & models via subclassing</a> <a class="nav-sublink" href="/guides/training_with_built_in_methods/">Training & evaluation with the built-in methods</a> <a class="nav-sublink" href="/guides/custom_train_step_in_jax/">Customizing `fit()` with JAX</a> <a class="nav-sublink" href="/guides/custom_train_step_in_tensorflow/">Customizing `fit()` with TensorFlow</a> <a class="nav-sublink" href="/guides/custom_train_step_in_torch/">Customizing `fit()` with PyTorch</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_jax/">Writing a custom training loop in JAX</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_tensorflow/">Writing a custom training loop in TensorFlow</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_torch/">Writing a custom training loop in PyTorch</a> <a class="nav-sublink" href="/guides/serialization_and_saving/">Serialization & saving</a> <a class="nav-sublink" href="/guides/customizing_saving_and_serialization/">Customizing saving & serialization</a> <a class="nav-sublink" href="/guides/writing_your_own_callbacks/">Writing your own callbacks</a> <a class="nav-sublink" href="/guides/transfer_learning/">Transfer learning & fine-tuning</a> <a class="nav-sublink" href="/guides/distributed_training_with_jax/">Distributed training with JAX</a> <a class="nav-sublink" href="/guides/distributed_training_with_tensorflow/">Distributed training with TensorFlow</a> <a class="nav-sublink" href="/guides/distributed_training_with_torch/">Distributed training with PyTorch</a> <a class="nav-sublink" href="/guides/distribution/">Distributed training with Keras 3</a> <a class="nav-sublink" href="/guides/migrating_to_keras_3/">Migrating Keras 2 code to Keras 3</a> <a class="nav-link" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/guides/'>Developer guides</a> / Making new layers and models via subclassing </div> <div class='k-content'> <h1 id="making-new-layers-and-models-via-subclassing">Making new layers and models via subclassing</h1> <p><strong>Author:</strong> <a href="https://twitter.com/fchollet">fchollet</a><br> <strong>Date created:</strong> 2019/03/01<br> <strong>Last modified:</strong> 2023/06/25<br> <strong>Description:</strong> Complete guide to writing <code>Layer</code> and <code>Model</code> objects from scratch.</p> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/making_new_layers_and_models_via_subclassing.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/guides/making_new_layers_and_models_via_subclassing.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>This guide will cover everything you need to know to build your own subclassed layers and models. In particular, you'll learn about the following features:</p> <ul> <li>The <code>Layer</code> class</li> <li>The <code>add_weight()</code> method</li> <li>Trainable and non-trainable weights</li> <li>The <code>build()</code> method</li> <li>Making sure your layers can be used with any backend</li> <li>The <code>add_loss()</code> method</li> <li>The <code>training</code> argument in <code>call()</code></li> <li>The <code>mask</code> argument in <code>call()</code></li> <li>Making sure your layers can be serialized</li> </ul> <p>Let's dive in.</p> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">ops</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> </code></pre></div> <hr /> <h2 id="the-layer-class-the-combination-of-state-weights-and-some-computation">The <code>Layer</code> class: the combination of state (weights) and some computation</h2> <p>One of the central abstractions in Keras is the <code>Layer</code> class. A layer encapsulates both a state (the layer's "weights") and a transformation from inputs to outputs (a "call", the layer's forward pass).</p> <p>Here's a densely-connected layer. It has two state variables: the variables <code>w</code> and <code>b</code>.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">Linear</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">units</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">input_dim</span><span class="o">=</span><span class="mi">32</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">input_dim</span><span class="p">,</span> <span class="n">units</span><span class="p">),</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">&quot;random_normal&quot;</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">units</span><span class="p">,),</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">&quot;zeros&quot;</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="k">return</span> <span class="n">ops</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span> </code></pre></div> <p>You would use a layer by calling it on some tensor input(s), much like a Python function.</p> <div class="codehilite"><pre><span></span><code><span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span> <span class="n">linear_layer</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="n">y</span> <span class="o">=</span> <span class="n">linear_layer</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">y</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>[[ 0.085416 -0.06821361 -0.00741937 -0.03429271] [ 0.085416 -0.06821361 -0.00741937 -0.03429271]] </code></pre></div> </div> <p>Note that the weights <code>w</code> and <code>b</code> are automatically tracked by the layer upon being set as layer attributes:</p> <div class="codehilite"><pre><span></span><code><span class="k">assert</span> <span class="n">linear_layer</span><span class="o">.</span><span class="n">weights</span> <span class="o">==</span> <span class="p">[</span><span class="n">linear_layer</span><span class="o">.</span><span class="n">w</span><span class="p">,</span> <span class="n">linear_layer</span><span class="o">.</span><span class="n">b</span><span class="p">]</span> </code></pre></div> <hr /> <h2 id="layers-can-have-nontrainable-weights">Layers can have non-trainable weights</h2> <p>Besides trainable weights, you can add non-trainable weights to a layer as well. Such weights are meant not to be taken into account during backpropagation, when you are training the layer.</p> <p>Here's how to add and use a non-trainable weight:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">ComputeSum</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_dim</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">total</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">&quot;zeros&quot;</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">input_dim</span><span class="p">,),</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">False</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">total</span><span class="o">.</span><span class="n">assign_add</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">))</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">total</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span> <span class="n">my_sum</span> <span class="o">=</span> <span class="n">ComputeSum</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span> <span class="n">y</span> <span class="o">=</span> <span class="n">my_sum</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span> <span class="n">y</span> <span class="o">=</span> <span class="n">my_sum</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>[2. 2.] [4. 4.] </code></pre></div> </div> <p>It's part of <code>layer.weights</code>, but it gets categorized as a non-trainable weight:</p> <div class="codehilite"><pre><span></span><code><span class="nb">print</span><span class="p">(</span><span class="s2">&quot;weights:&quot;</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">my_sum</span><span class="o">.</span><span class="n">weights</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;non-trainable weights:&quot;</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">my_sum</span><span class="o">.</span><span class="n">non_trainable_weights</span><span class="p">))</span> <span class="c1"># It&#39;s not included in the trainable weights:</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;trainable_weights:&quot;</span><span class="p">,</span> <span class="n">my_sum</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>weights: 1 non-trainable weights: 1 trainable_weights: [] </code></pre></div> </div> <hr /> <h2 id="best-practice-deferring-weight-creation-until-the-shape-of-the-inputs-is-known">Best practice: deferring weight creation until the shape of the inputs is known</h2> <p>Our <code>Linear</code> layer above took an <code>input_dim</code> argument that was used to compute the shape of the weights <code>w</code> and <code>b</code> in <code>__init__()</code>:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">Linear</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">units</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">input_dim</span><span class="o">=</span><span class="mi">32</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">input_dim</span><span class="p">,</span> <span class="n">units</span><span class="p">),</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">&quot;random_normal&quot;</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">units</span><span class="p">,),</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">&quot;zeros&quot;</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="k">return</span> <span class="n">ops</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span> </code></pre></div> <p>In many cases, you may not know in advance the size of your inputs, and you would like to lazily create weights when that value becomes known, some time after instantiating the layer.</p> <p>In the Keras API, we recommend creating layer weights in the <code>build(self, inputs_shape)</code> method of your layer. Like this:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">Linear</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">units</span><span class="o">=</span><span class="mi">32</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span> <span class="o">=</span> <span class="n">units</span> <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">input_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">),</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">&quot;random_normal&quot;</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">,),</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">&quot;random_normal&quot;</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="k">return</span> <span class="n">ops</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span> </code></pre></div> <p>The <code>__call__()</code> method of your layer will automatically run build the first time it is called. You now have a layer that's lazy and thus easier to use:</p> <div class="codehilite"><pre><span></span><code><span class="c1"># At instantiation, we don&#39;t know on what inputs this is going to get called</span> <span class="n">linear_layer</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="mi">32</span><span class="p">)</span> <span class="c1"># The layer&#39;s weights are created dynamically the first time the layer is called</span> <span class="n">y</span> <span class="o">=</span> <span class="n">linear_layer</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> </code></pre></div> <p>Implementing <code>build()</code> separately as shown above nicely separates creating weights only once from using weights in every call.</p> <hr /> <h2 id="layers-are-recursively-composable">Layers are recursively composable</h2> <p>If you assign a Layer instance as an attribute of another Layer, the outer layer will start tracking the weights created by the inner layer.</p> <p>We recommend creating such sublayers in the <code>__init__()</code> method and leave it to the first <code>__call__()</code> to trigger building their weights.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">MLPBlock</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_1</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="mi">32</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_2</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="mi">32</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_3</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_1</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">activations</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">activations</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_3</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">mlp</span> <span class="o">=</span> <span class="n">MLPBlock</span><span class="p">()</span> <span class="n">y</span> <span class="o">=</span> <span class="n">mlp</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">64</span><span class="p">)))</span> <span class="c1"># The first call to the `mlp` will create the weights</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;weights:&quot;</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">mlp</span><span class="o">.</span><span class="n">weights</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;trainable weights:&quot;</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">mlp</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>weights: 6 trainable weights: 6 </code></pre></div> </div> <hr /> <h2 id="backendagnostic-layers-and-backendspecific-layers">Backend-agnostic layers and backend-specific layers</h2> <p>As long as a layer only uses APIs from the <code>keras.ops</code> namespace (or other Keras namespaces such as <code>keras.activations</code>, <code>keras.random</code>, or <code>keras.layers</code>), then it can be used with any backend &ndash; TensorFlow, JAX, or PyTorch.</p> <p>All layers you've seen so far in this guide work with all Keras backends.</p> <p>The <code>keras.ops</code> namespace gives you access to:</p> <ul> <li>The NumPy API, e.g. <code>ops.matmul</code>, <code>ops.sum</code>, <code>ops.reshape</code>, <code>ops.stack</code>, etc.</li> <li>Neural networks-specific APIs such as <code>ops.softmax</code>, <code>ops</code>.conv<code>,</code>ops.binary_crossentropy<code>,</code>ops.relu`, etc.</li> </ul> <p>You can also use backend-native APIs in your layers (such as <a href="https://www.tensorflow.org/api_docs/python/tf/nn"><code>tf.nn</code></a> functions), but if you do this, then your layer will only be usable with the backend in question. For instance, you could write the following JAX-specific layer using <code>jax.numpy</code>:</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">jax</span> <span class="k">class</span> <span class="nc">Linear</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="o">...</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="k">return</span> <span class="n">jax</span><span class="o">.</span><span class="n">numpy</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span> </code></pre></div> <p>This would be the equivalent TensorFlow-specific layer:</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span> <span class="k">class</span> <span class="nc">Linear</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="o">...</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="k">return</span> <span class="n">tf</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span> </code></pre></div> <p>And this would be the equivalent PyTorch-specific layer:</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">torch</span> <span class="k">class</span> <span class="nc">Linear</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="o">...</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="k">return</span> <span class="n">torch</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span> </code></pre></div> <p>Because cross-backend compatibility is a tremendously useful property, we strongly recommend that you seek to always make your layers backend-agnostic by leveraging only Keras APIs.</p> <hr /> <h2 id="the-addloss-method">The <code>add_loss()</code> method</h2> <p>When writing the <code>call()</code> method of a layer, you can create loss tensors that you will want to use later, when writing your training loop. This is doable by calling <code>self.add_loss(value)</code>:</p> <div class="codehilite"><pre><span></span><code><span class="c1"># A layer that creates an activity regularization loss</span> <span class="k">class</span> <span class="nc">ActivityRegularizationLayer</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rate</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">rate</span> <span class="o">=</span> <span class="n">rate</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_loss</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">rate</span> <span class="o">*</span> <span class="n">ops</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">inputs</span><span class="p">))</span> <span class="k">return</span> <span class="n">inputs</span> </code></pre></div> <p>These losses (including those created by any inner layer) can be retrieved via <code>layer.losses</code>. This property is reset at the start of every <code>__call__()</code> to the top-level layer, so that <code>layer.losses</code> always contains the loss values created during the last forward pass.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">OuterLayer</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">activity_reg</span> <span class="o">=</span> <span class="n">ActivityRegularizationLayer</span><span class="p">(</span><span class="mf">1e-2</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">activity_reg</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">layer</span> <span class="o">=</span> <span class="n">OuterLayer</span><span class="p">()</span> <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">losses</span><span class="p">)</span> <span class="o">==</span> <span class="mi">0</span> <span class="c1"># No losses yet since the layer has never been called</span> <span class="n">_</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)))</span> <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">losses</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span> <span class="c1"># We created one loss value</span> <span class="c1"># `layer.losses` gets reset at the start of each __call__</span> <span class="n">_</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)))</span> <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">losses</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span> <span class="c1"># This is the loss created during the call above</span> </code></pre></div> <p>In addition, the <code>loss</code> property also contains regularization losses created for the weights of any inner layer:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">OuterLayerWithKernelRegularizer</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span> <span class="mi">32</span><span class="p">,</span> <span class="n">kernel_regularizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">regularizers</span><span class="o">.</span><span class="n">l2</span><span class="p">(</span><span class="mf">1e-3</span><span class="p">)</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">layer</span> <span class="o">=</span> <span class="n">OuterLayerWithKernelRegularizer</span><span class="p">()</span> <span class="n">_</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)))</span> <span class="c1"># This is `1e-3 * sum(layer.dense.kernel ** 2)`,</span> <span class="c1"># created by the `kernel_regularizer` above.</span> <span class="nb">print</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">losses</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>[Array(0.00217911, dtype=float32)] </code></pre></div> </div> <p>These losses are meant to be taken into account when writing custom training loops.</p> <p>They also work seamlessly with <code>fit()</code> (they get automatically summed and added to the main loss, if any):</p> <div class="codehilite"><pre><span></span><code><span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,))</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">ActivityRegularizationLayer</span><span class="p">()(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> <span class="c1"># If there is a loss passed in `compile`, the regularization</span> <span class="c1"># losses get added to it</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="s2">&quot;adam&quot;</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s2">&quot;mse&quot;</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)),</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))</span> <span class="c1"># It&#39;s also possible not to pass any loss in `compile`,</span> <span class="c1"># since the model already has a loss to minimize, via the `add_loss`</span> <span class="c1"># call during the forward pass!</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="s2">&quot;adam&quot;</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)),</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - loss: 0.2650 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - loss: 0.0050 &lt;keras.src.callbacks.history.History at 0x146f71960&gt; </code></pre></div> </div> <hr /> <h2 id="you-can-optionally-enable-serialization-on-your-layers">You can optionally enable serialization on your layers</h2> <p>If you need your custom layers to be serializable as part of a <a href="/guides/functional_api/">Functional model</a>, you can optionally implement a <code>get_config()</code> method:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">Linear</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">units</span><span class="o">=</span><span class="mi">32</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span> <span class="o">=</span> <span class="n">units</span> <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">input_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">),</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">&quot;random_normal&quot;</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">,),</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">&quot;random_normal&quot;</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="k">return</span> <span class="n">ops</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span> <span class="k">def</span> <span class="nf">get_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="k">return</span> <span class="p">{</span><span class="s2">&quot;units&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">}</span> <span class="c1"># Now you can recreate the layer from its config:</span> <span class="n">layer</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="mi">64</span><span class="p">)</span> <span class="n">config</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">get_config</span><span class="p">()</span> <span class="nb">print</span><span class="p">(</span><span class="n">config</span><span class="p">)</span> <span class="n">new_layer</span> <span class="o">=</span> <span class="n">Linear</span><span class="o">.</span><span class="n">from_config</span><span class="p">(</span><span class="n">config</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>{&#39;units&#39;: 64} </code></pre></div> </div> <p>Note that the <code>__init__()</code> method of the base <code>Layer</code> class takes some keyword arguments, in particular a <code>name</code> and a <code>dtype</code>. It's good practice to pass these arguments to the parent class in <code>__init__()</code> and to include them in the layer config:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">Linear</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">units</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span> <span class="o">=</span> <span class="n">units</span> <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">input_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">),</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">&quot;random_normal&quot;</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">,),</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">&quot;random_normal&quot;</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="k">return</span> <span class="n">ops</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span> <span class="k">def</span> <span class="nf">get_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="n">config</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">get_config</span><span class="p">()</span> <span class="n">config</span><span class="o">.</span><span class="n">update</span><span class="p">({</span><span class="s2">&quot;units&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">})</span> <span class="k">return</span> <span class="n">config</span> <span class="n">layer</span> <span class="o">=</span> <span class="n">Linear</span><span class="p">(</span><span class="mi">64</span><span class="p">)</span> <span class="n">config</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">get_config</span><span class="p">()</span> <span class="nb">print</span><span class="p">(</span><span class="n">config</span><span class="p">)</span> <span class="n">new_layer</span> <span class="o">=</span> <span class="n">Linear</span><span class="o">.</span><span class="n">from_config</span><span class="p">(</span><span class="n">config</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>{&#39;name&#39;: &#39;linear_7&#39;, &#39;trainable&#39;: True, &#39;dtype&#39;: &#39;float32&#39;, &#39;units&#39;: 64} </code></pre></div> </div> <p>If you need more flexibility when deserializing the layer from its config, you can also override the <code>from_config()</code> class method. This is the base implementation of <code>from_config()</code>:</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">from_config</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">config</span><span class="p">):</span> <span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="o">**</span><span class="n">config</span><span class="p">)</span> </code></pre></div> <p>To learn more about serialization and saving, see the complete <a href="/guides/serialization_and_saving/">guide to saving and serializing models</a>.</p> <hr /> <h2 id="privileged-training-argument-in-the-call-method">Privileged <code>training</code> argument in the <code>call()</code> method</h2> <p>Some layers, in particular the <code>BatchNormalization</code> layer and the <code>Dropout</code> layer, have different behaviors during training and inference. For such layers, it is standard practice to expose a <code>training</code> (boolean) argument in the <code>call()</code> method.</p> <p>By exposing this argument in <code>call()</code>, you enable the built-in training and evaluation loops (e.g. <code>fit()</code>) to correctly use the layer in training and inference.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">CustomDropout</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rate</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">rate</span> <span class="o">=</span> <span class="n">rate</span> <span class="bp">self</span><span class="o">.</span><span class="n">seed_generator</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">SeedGenerator</span><span class="p">(</span><span class="mi">1337</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="k">if</span> <span class="n">training</span><span class="p">:</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">rate</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">rate</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">seed_generator</span> <span class="p">)</span> <span class="k">return</span> <span class="n">inputs</span> </code></pre></div> <hr /> <h2 id="privileged-mask-argument-in-the-call-method">Privileged <code>mask</code> argument in the <code>call()</code> method</h2> <p>The other privileged argument supported by <code>call()</code> is the <code>mask</code> argument.</p> <p>You will find it in all Keras RNN layers. A mask is a boolean tensor (one boolean value per timestep in the input) used to skip certain input timesteps when processing timeseries data.</p> <p>Keras will automatically pass the correct <code>mask</code> argument to <code>__call__()</code> for layers that support it, when a mask is generated by a prior layer. Mask-generating layers are the <code>Embedding</code> layer configured with <code>mask_zero=True</code>, and the <code>Masking</code> layer.</p> <hr /> <h2 id="the-model-class">The <code>Model</code> class</h2> <p>In general, you will use the <code>Layer</code> class to define inner computation blocks, and will use the <code>Model</code> class to define the outer model &ndash; the object you will train.</p> <p>For instance, in a ResNet50 model, you would have several ResNet blocks subclassing <code>Layer</code>, and a single <code>Model</code> encompassing the entire ResNet50 network.</p> <p>The <code>Model</code> class has the same API as <code>Layer</code>, with the following differences:</p> <ul> <li>It exposes built-in training, evaluation, and prediction loops (<code>model.fit()</code>, <code>model.evaluate()</code>, <code>model.predict()</code>).</li> <li>It exposes the list of its inner layers, via the <code>model.layers</code> property.</li> <li>It exposes saving and serialization APIs (<code>save()</code>, <code>save_weights()</code>...)</li> </ul> <p>Effectively, the <code>Layer</code> class corresponds to what we refer to in the literature as a "layer" (as in "convolution layer" or "recurrent layer") or as a "block" (as in "ResNet block" or "Inception block").</p> <p>Meanwhile, the <code>Model</code> class corresponds to what is referred to in the literature as a "model" (as in "deep learning model") or as a "network" (as in "deep neural network").</p> <p>So if you're wondering, "should I use the <code>Layer</code> class or the <code>Model</code> class?", ask yourself: will I need to call <code>fit()</code> on it? Will I need to call <code>save()</code> on it? If so, go with <code>Model</code>. If not (either because your class is just a block in a bigger system, or because you are writing training &amp; saving code yourself), use <code>Layer</code>.</p> <p>For instance, we could take our mini-resnet example above, and use it to build a <code>Model</code> that we could train with <code>fit()</code>, and that we could save with <code>save_weights()</code>:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">ResNet</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="mi">1000</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">block_1</span> <span class="o">=</span> <span class="n">ResNetBlock</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">block_2</span> <span class="o">=</span> <span class="n">ResNetBlock</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">global_pool</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">GlobalAveragePooling2D</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">classifier</span> <span class="o">=</span> <span class="n">Dense</span><span class="p">(</span><span class="n">num_classes</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">block_1</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">block_2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">global_pool</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">classifier</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">resnet</span> <span class="o">=</span> <span class="n">ResNet</span><span class="p">()</span> <span class="n">dataset</span> <span class="o">=</span> <span class="o">...</span> <span class="n">resnet</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> <span class="n">resnet</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="n">filepath</span><span class="o">.</span><span class="n">keras</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="putting-it-all-together-an-endtoend-example">Putting it all together: an end-to-end example</h2> <p>Here's what you've learned so far:</p> <ul> <li>A <code>Layer</code> encapsulate a state (created in <code>__init__()</code> or <code>build()</code>) and some computation (defined in <code>call()</code>).</li> <li>Layers can be recursively nested to create new, bigger computation blocks.</li> <li>Layers are backend-agnostic as long as they only use Keras APIs. You can use backend-native APIs (such as <code>jax.numpy</code>, <code>torch.nn</code> or <a href="https://www.tensorflow.org/api_docs/python/tf/nn"><code>tf.nn</code></a>), but then your layer will only be usable with that specific backend.</li> <li>Layers can create and track losses (typically regularization losses) via <code>add_loss()</code>.</li> <li>The outer container, the thing you want to train, is a <code>Model</code>. A <code>Model</code> is just like a <code>Layer</code>, but with added training and serialization utilities.</li> </ul> <p>Let's put all of these things together into an end-to-end example: we're going to implement a Variational AutoEncoder (VAE) in a backend-agnostic fashion &ndash; so that it runs the same with TensorFlow, JAX, and PyTorch. We'll train it on MNIST digits.</p> <p>Our VAE will be a subclass of <code>Model</code>, built as a nested composition of layers that subclass <code>Layer</code>. It will feature a regularization loss (KL divergence).</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">Sampling</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="w"> </span><span class="sd">&quot;&quot;&quot;Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.&quot;&quot;&quot;</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">seed_generator</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">SeedGenerator</span><span class="p">(</span><span class="mi">1337</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">z_mean</span><span class="p">,</span> <span class="n">z_log_var</span> <span class="o">=</span> <span class="n">inputs</span> <span class="n">batch</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">z_mean</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="n">dim</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">z_mean</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span> <span class="n">epsilon</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">dim</span><span class="p">),</span> <span class="n">seed</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">seed_generator</span><span class="p">)</span> <span class="k">return</span> <span class="n">z_mean</span> <span class="o">+</span> <span class="n">ops</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">z_log_var</span><span class="p">)</span> <span class="o">*</span> <span class="n">epsilon</span> <span class="k">class</span> <span class="nc">Encoder</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="w"> </span><span class="sd">&quot;&quot;&quot;Maps MNIST digits to a triplet (z_mean, z_log_var, z).&quot;&quot;&quot;</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">latent_dim</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">intermediate_dim</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;encoder&quot;</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_proj</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">intermediate_dim</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_mean</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">latent_dim</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_log_var</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">latent_dim</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampling</span> <span class="o">=</span> <span class="n">Sampling</span><span class="p">()</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_proj</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">z_mean</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_mean</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">z_log_var</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_log_var</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampling</span><span class="p">((</span><span class="n">z_mean</span><span class="p">,</span> <span class="n">z_log_var</span><span class="p">))</span> <span class="k">return</span> <span class="n">z_mean</span><span class="p">,</span> <span class="n">z_log_var</span><span class="p">,</span> <span class="n">z</span> <span class="k">class</span> <span class="nc">Decoder</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="w"> </span><span class="sd">&quot;&quot;&quot;Converts z, the encoded digit vector, back into a readable digit.&quot;&quot;&quot;</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">original_dim</span><span class="p">,</span> <span class="n">intermediate_dim</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;decoder&quot;</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_proj</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">intermediate_dim</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_output</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">original_dim</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;sigmoid&quot;</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_proj</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_output</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">class</span> <span class="nc">VariationalAutoEncoder</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="w"> </span><span class="sd">&quot;&quot;&quot;Combines the encoder and decoder into an end-to-end model for training.&quot;&quot;&quot;</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">original_dim</span><span class="p">,</span> <span class="n">intermediate_dim</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">latent_dim</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;autoencoder&quot;</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span> <span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">original_dim</span> <span class="o">=</span> <span class="n">original_dim</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">Encoder</span><span class="p">(</span><span class="n">latent_dim</span><span class="o">=</span><span class="n">latent_dim</span><span class="p">,</span> <span class="n">intermediate_dim</span><span class="o">=</span><span class="n">intermediate_dim</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">Decoder</span><span class="p">(</span><span class="n">original_dim</span><span class="p">,</span> <span class="n">intermediate_dim</span><span class="o">=</span><span class="n">intermediate_dim</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">z_mean</span><span class="p">,</span> <span class="n">z_log_var</span><span class="p">,</span> <span class="n">z</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">reconstructed</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="n">z</span><span class="p">)</span> <span class="c1"># Add KL divergence regularization loss.</span> <span class="n">kl_loss</span> <span class="o">=</span> <span class="o">-</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">ops</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span> <span class="n">z_log_var</span> <span class="o">-</span> <span class="n">ops</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">z_mean</span><span class="p">)</span> <span class="o">-</span> <span class="n">ops</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="n">z_log_var</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_loss</span><span class="p">(</span><span class="n">kl_loss</span><span class="p">)</span> <span class="k">return</span> <span class="n">reconstructed</span> </code></pre></div> <p>Let's train it on MNIST using the <code>fit()</code> API:</p> <div class="codehilite"><pre><span></span><code><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">_</span><span class="p">),</span> <span class="n">_</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">mnist</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span> <span class="n">x_train</span> <span class="o">=</span> <span class="n">x_train</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">60000</span><span class="p">,</span> <span class="mi">784</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">&quot;float32&quot;</span><span class="p">)</span> <span class="o">/</span> <span class="mi">255</span> <span class="n">original_dim</span> <span class="o">=</span> <span class="mi">784</span> <span class="n">vae</span> <span class="o">=</span> <span class="n">VariationalAutoEncoder</span><span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">32</span><span class="p">)</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">)</span> <span class="n">vae</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">MeanSquaredError</span><span class="p">())</span> <span class="n">vae</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">x_train</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">64</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/2 938/938 ━━━━━━━━━━━━━━━━━━━━ 2s 1ms/step - loss: 0.0942 Epoch 2/2 938/938 ━━━━━━━━━━━━━━━━━━━━ 1s 859us/step - loss: 0.0677 &lt;keras.src.callbacks.history.History at 0x146fe62f0&gt; </code></pre></div> </div> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#making-new-layers-and-models-via-subclassing'>Making new layers and models via subclassing</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#the-layer-class-the-combination-of-state-weights-and-some-computation'>The <code>Layer</code> class: the combination of state (weights) and some computation</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#layers-can-have-nontrainable-weights'>Layers can have non-trainable weights</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#best-practice-deferring-weight-creation-until-the-shape-of-the-inputs-is-known'>Best practice: deferring weight creation until the shape of the inputs is known</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#layers-are-recursively-composable'>Layers are recursively composable</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#backendagnostic-layers-and-backendspecific-layers'>Backend-agnostic layers and backend-specific layers</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#the-addloss-method'>The <code>add_loss()</code> method</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#you-can-optionally-enable-serialization-on-your-layers'>You can optionally enable serialization on your layers</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#privileged-training-argument-in-the-call-method'>Privileged <code>training</code> argument in the <code>call()</code> method</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#privileged-mask-argument-in-the-call-method'>Privileged <code>mask</code> argument in the <code>call()</code> method</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#the-model-class'>The <code>Model</code> class</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#putting-it-all-together-an-endtoend-example'>Putting it all together: an end-to-end example</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>

Pages: 1 2 3 4 5 6 7 8 9 10