# Keras debugging tips

**Author:** [fchollet](https://twitter.com/fchollet)<br>
**Date created:** 2020/05/16<br>
**Last modified:** 2023/11/16<br>
**Description:** Four simple tips to help you debug your Keras code.

ⓘ This example uses Keras 3

[**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_recipes/ipynb/debugging_tips.ipynb) • [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_recipes/debugging_tips.py)

---

## Introduction

It's generally possible to do almost anything in Keras *without writing code* per se: whether you're implementing a new type of GAN
or the latest convnet architecture for image segmentation, you can usually stick to calling built-in methods. Because all built-in methods do extensive input validation checks, you will have little to no debugging to do. A Functional API model made entirely of built-in layers will work on first try – if you can compile it, it will run.</p> <p>However, sometimes, you will need to dive deeper and write your own code. Here are some common examples:</p> <ul> <li>Creating a new <code>Layer</code> subclass.</li> <li>Creating a custom <code>Metric</code> subclass.</li> <li>Implementing a custom <code>train_step</code> on a <code>Model</code>.</li> </ul> <p>This document provides a few simple tips to help you navigate debugging in these situations.</p> <hr /> <h2 id="tip-1-test-each-part-before-you-test-the-whole">Tip 1: test each part before you test the whole</h2> <p>If you've created any object that has a chance of not working as expected, don't just drop it in your end-to-end process and watch sparks fly. Rather, test your custom object in isolation first. This may seem obvious – but you'd be surprised how often people don't start with this.</p> <ul> <li>If you write a custom layer, don't call <code>fit()</code> on your entire model just yet. Call your layer on some test data first.</li> <li>If you write a custom metric, start by printing its output for some reference inputs.</li> </ul> <p>Here's a simple example. Let's write a custom layer a bug in it:</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span> <span class="c1"># The last example uses tf.GradientTape and thus requires TensorFlow.</span> <span class="c1"># However, all tips here are applicable with all backends.</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"tensorflow"</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">ops</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span> <span class="k">class</span> <span class="nc">MyAntirectifier</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="n">output_dim</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">output_dim</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="n">output_dim</span><span 
class="p">),</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">"he_normal"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"kernel"</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="c1"># Take the positive part of the input</span> <span class="n">pos</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="c1"># Take the negative part of the input</span> <span class="n">neg</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="o">-</span><span class="n">inputs</span><span class="p">)</span> <span class="c1"># Concatenate the positive and negative parts</span> <span class="n">concatenated</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">pos</span><span class="p">,</span> <span class="n">neg</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="c1"># Project the concatenation down to the same dimensionality as the input</span> <span class="k">return</span> <span class="n">ops</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">concatenated</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel</span><span class="p">)</span> </code></pre></div> <p>Now, rather than using it in a end-to-end model directly, let's try to call the layer on some test data:</p> <div class="codehilite"><pre><span></span><code><span class="n">x</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span> <span class="n">y</span> <span class="o">=</span> <span class="n">MyAntirectifier</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> </code></pre></div> <p>We get the following error:</p> <div class="codehilite"><pre><span></span><code>... 1 x = tf.random.normal(shape=(2, 5)) ----> 2 y = MyAntirectifier()(x) ... 17 neg = tf.nn.relu(-inputs) 18 concatenated = tf.concat([pos, neg], axis=0) ---> 19 return tf.matmul(concatenated, self.kernel) ... InvalidArgumentError: Matrix size-incompatible: In[0]: [4,5], In[1]: [10,5] [Op:MatMul] </code></pre></div> <p>Looks like our input tensor in the <code>matmul</code> op may have an incorrect shape. 
Let's add a print statement to check the actual shapes:

```python
class MyAntirectifier(layers.Layer):
    def build(self, input_shape):
        output_dim = input_shape[-1]
        self.kernel = self.add_weight(
            shape=(output_dim * 2, output_dim),
            initializer="he_normal",
            name="kernel",
            trainable=True,
        )

    def call(self, inputs):
        pos = ops.relu(inputs)
        neg = ops.relu(-inputs)
        print("pos.shape:", pos.shape)
        print("neg.shape:", neg.shape)
        concatenated = ops.concatenate([pos, neg], axis=0)
        print("concatenated.shape:", concatenated.shape)
        print("kernel.shape:", self.kernel.shape)
        return ops.matmul(concatenated, self.kernel)
```

We get the following:

```
pos.shape: (2, 5)
neg.shape: (2, 5)
concatenated.shape: (4, 5)
kernel.shape: (10, 5)
```

Turns out we had the wrong axis for the `concat` op! We should be concatenating `neg` and `pos` along the feature axis 1, not the batch axis 0. Here's the correct version:

```python
class MyAntirectifier(layers.Layer):
    def build(self, input_shape):
        output_dim = input_shape[-1]
        self.kernel = self.add_weight(
            shape=(output_dim * 2, output_dim),
            initializer="he_normal",
            name="kernel",
            trainable=True,
        )

    def call(self, inputs):
        pos = ops.relu(inputs)
        neg = ops.relu(-inputs)
        print("pos.shape:", pos.shape)
        print("neg.shape:", neg.shape)
        concatenated = ops.concatenate([pos, neg], axis=1)
        print("concatenated.shape:", concatenated.shape)
        print("kernel.shape:", self.kernel.shape)
        return ops.matmul(concatenated, self.kernel)
```

Now our code works fine:

```python
x = keras.random.normal(shape=(2, 5))
y = MyAntirectifier()(x)
```

```
pos.shape: (2, 5)
neg.shape: (2, 5)
concatenated.shape: (2, 10)
kernel.shape: (10, 5)
```
---

## Tip 2: use `model.summary()` and `plot_model()` to check layer output shapes

If you're working with complex network topologies, you're going to need a way to visualize how your layers are connected and how they transform the data that passes through them.

Here's an example. Consider this model with three inputs and two outputs (lifted from the [Functional API guide](https://keras.io/guides/functional_api/#manipulate-complex-graph-topologies)):

```python
num_tags = 12  # Number of unique issue tags
num_words = 10000  # Size of vocabulary obtained when preprocessing text data
num_departments = 4  # Number of departments for predictions

title_input = keras.Input(
    shape=(None,), name="title"
)  # Variable-length sequence of ints
body_input = keras.Input(shape=(None,), name="body")  # Variable-length sequence of ints
tags_input = keras.Input(
    shape=(num_tags,), name="tags"
)  # Binary vectors of size `num_tags`

# Embed each word in the title into a 64-dimensional vector
title_features = layers.Embedding(num_words, 64)(title_input)
# Embed each word in the text into a 64-dimensional vector
body_features = layers.Embedding(num_words, 64)(body_input)

# Reduce sequence of embedded words in the title into a single 128-dimensional vector
title_features = layers.LSTM(128)(title_features)
# Reduce sequence of embedded words in the body into a single 32-dimensional vector
body_features = layers.LSTM(32)(body_features)

# Merge all available features into a single large vector via concatenation
x = layers.concatenate([title_features, body_features, tags_input])

# Stick a logistic regression for priority prediction on top of the features
priority_pred = layers.Dense(1, name="priority")(x)
# Stick a department classifier on top of the features
department_pred = layers.Dense(num_departments, name="department")(x)

# Instantiate an end-to-end model predicting both priority and department
model = keras.Model(
    inputs=[title_input, body_input, tags_input],
    outputs=[priority_pred, department_pred],
)
```

Calling `summary()` can help you check the output shape of each layer:

```python
model.summary()
```

```
Model: "functional_1"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)        ┃ Output Shape      ┃ Param # ┃ Connected to         ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ title (InputLayer)  │ (None, None)      │       0 │ -                    │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ body (InputLayer)   │ (None, None)      │       0 │ -                    │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ embedding           │ (None, None, 64)  │ 640,000 │ title[0][0]          │
│ (Embedding)         │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ embedding_1         │ (None, None, 64)  │ 640,000 │ body[0][0]           │
│ (Embedding)         │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ lstm (LSTM)         │ (None, 128)       │  98,816 │ embedding[0][0]      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ lstm_1 (LSTM)       │ (None, 32)        │  12,416 │ embedding_1[0][0]    │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ tags (InputLayer)   │ (None, 12)        │       0 │ -                    │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ concatenate         │ (None, 172)       │       0 │ lstm[0][0],          │
│ (Concatenate)       │                   │         │ lstm_1[0][0],        │
│                     │                   │         │ tags[0][0]           │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ priority (Dense)    │ (None, 1)         │     173 │ concatenate[0][0]    │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ department (Dense)  │ (None, 4)         │     692 │ concatenate[0][0]    │
└─────────────────────┴───────────────────┴─────────┴──────────────────────┘
 Total params: 1,392,097 (5.31 MB)
 Trainable params: 1,392,097 (5.31 MB)
 Non-trainable params: 0 (0.00 B)
```

You can also visualize the entire network topology alongside output shapes using `plot_model`:

```python
keras.utils.plot_model(model, show_shapes=True)
```

![png](/img/examples/keras_recipes/debugging_tips/debugging_tips_15_0.png)

With this plot, any connectivity-level error becomes immediately obvious.
---

## Tip 3: to debug what happens during `fit()`, use `run_eagerly=True`

The `fit()` method is fast: it runs a well-optimized, fully-compiled computation graph. That's great for performance, but it also means that the code you're executing isn't the Python code you've written. This can be problematic when debugging. As you may recall, Python is slow – so we use it as a staging language, not as an execution language.

Thankfully, there's an easy way to run your code in "debug mode", fully eagerly: pass `run_eagerly=True` to `compile()`. Your call to `fit()` will then be executed line by line, without any optimization. It's slower, but it makes it possible to print the value of intermediate tensors, or to use a Python debugger. Great for debugging.

Here's a basic example: let's write a really simple model with a custom `train_step()` method. Our model just implements gradient descent, but instead of first-order gradients, it uses a combination of first-order and second-order gradients. Pretty simple so far.

Can you spot what we're doing wrong?

```python
class MyModel(keras.Model):
    def train_step(self, data):
        inputs, targets = data
        trainable_vars = self.trainable_variables
        with tf.GradientTape() as tape2:
            with tf.GradientTape() as tape1:
                y_pred = self(inputs, training=True)  # Forward pass
                # Compute the loss value
                # (the loss function is configured in `compile()`)
                loss = self.compute_loss(y=targets, y_pred=y_pred)
            # Compute first-order gradients
            dl_dw = tape1.gradient(loss, trainable_vars)
        # Compute second-order gradients
        d2l_dw2 = tape2.gradient(dl_dw, trainable_vars)

        # Combine first-order and second-order gradients
        grads = [0.5 * w1 + 0.5 * w2 for (w1, w2) in zip(d2l_dw2, dl_dw)]

        # Update weights
        self.optimizer.apply_gradients(zip(grads, trainable_vars))

        # Update metrics (includes the metric that tracks the loss)
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(targets, y_pred)

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
```
The general idea being to use larger batches and a larger learning rate than usual, since our "improved" gradients should lead us to quicker convergence.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Construct an instance of MyModel</span> <span class="k">def</span> <span class="nf">get_model</span><span class="p">():</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">784</span><span class="p">,))</span> <span class="n">intermediate</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"softmax"</span><span class="p">)(</span><span class="n">intermediate</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">MyModel</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> <span class="k">return</span> <span class="n">model</span> <span class="c1"># Prepare data</span> <span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="n">_</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">mnist</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span> <span class="n">x_train</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">784</span><span class="p">))</span> <span class="o">/</span> <span class="mi">255</span> <span class="n">model</span> <span class="o">=</span> <span class="n">get_model</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">SGD</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">),</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"sparse_categorical_crossentropy"</span><span class="p">,</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">3</span><span 
class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">1024</span><span class="p">,</span> <span class="n">validation_split</span><span class="o">=</span><span class="mf">0.1</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/3 53/53 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - loss: 2.4264 - val_loss: 2.3036 Epoch 2/3 53/53 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - loss: 2.3111 - val_loss: 2.3387 Epoch 3/3 53/53 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - loss: 2.3442 - val_loss: 2.3697 <keras.src.callbacks.history.History at 0x29a899600> </code></pre></div> </div> <p>Oh no, it doesn't converge! Something is not working as planned.</p> <p>Time for some step-by-step printing of what's going on with our gradients.</p> <p>We add various <code>print</code> statements in the <code>train_step</code> method, and we make sure to pass <code>run_eagerly=True</code> to <code>compile()</code> to run our code step-by-step, eagerly.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">MyModel</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span> <span class="nb">print</span><span class="p">()</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"----Start of step: </span><span class="si">%d</span><span class="s2">"</span> <span class="o">%</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">step_counter</span><span class="p">,))</span> <span class="bp">self</span><span class="o">.</span><span class="n">step_counter</span> <span class="o">+=</span> <span class="mi">1</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span> <span class="o">=</span> <span class="n">data</span> <span class="n">trainable_vars</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainable_variables</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape2</span><span class="p">:</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape1</span><span class="p">:</span> <span class="n">y_pred</span> <span class="o">=</span> <span class="bp">self</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="c1"># Forward pass</span> <span class="c1"># Compute the loss value</span> <span class="c1"># (the loss function is configured in `compile()`)</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">compute_loss</span><span class="p">(</span><span class="n">y</span><span class="o">=</span><span class="n">targets</span><span class="p">,</span> <span class="n">y_pred</span><span class="o">=</span><span class="n">y_pred</span><span class="p">)</span> <span class="c1"># Compute first-order 
gradients</span> <span class="n">dl_dw</span> <span class="o">=</span> <span class="n">tape1</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">)</span> <span class="c1"># Compute second-order gradients</span> <span class="n">d2l_dw2</span> <span class="o">=</span> <span class="n">tape2</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">dl_dw</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Max of dl_dw[0]: </span><span class="si">%.4f</span><span class="s2">"</span> <span class="o">%</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_max</span><span class="p">(</span><span class="n">dl_dw</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Min of dl_dw[0]: </span><span class="si">%.4f</span><span class="s2">"</span> <span class="o">%</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_min</span><span class="p">(</span><span class="n">dl_dw</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Mean of dl_dw[0]: </span><span class="si">%.4f</span><span class="s2">"</span> <span class="o">%</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">dl_dw</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"-"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Max of d2l_dw2[0]: </span><span class="si">%.4f</span><span class="s2">"</span> <span class="o">%</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_max</span><span class="p">(</span><span class="n">d2l_dw2</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Min of d2l_dw2[0]: </span><span class="si">%.4f</span><span class="s2">"</span> <span class="o">%</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_min</span><span class="p">(</span><span class="n">d2l_dw2</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Mean of d2l_dw2[0]: </span><span class="si">%.4f</span><span class="s2">"</span> <span class="o">%</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">d2l_dw2</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span> <span class="c1"># Combine first-order and second-order gradients</span> <span class="n">grads</span> <span class="o">=</span> <span class="p">[</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">w1</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">w2</span> <span class="k">for</span> <span class="p">(</span><span class="n">w1</span><span class="p">,</span> <span class="n">w2</span><span class="p">)</span> <span 
class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">d2l_dw2</span><span class="p">,</span> <span class="n">dl_dw</span><span class="p">)]</span> <span class="c1"># Update weights</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">grads</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">))</span> <span class="c1"># Update metrics (includes the metric that tracks the loss)</span> <span class="k">for</span> <span class="n">metric</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics</span><span class="p">:</span> <span class="k">if</span> <span class="n">metric</span><span class="o">.</span><span class="n">name</span> <span class="o">==</span> <span class="s2">"loss"</span><span class="p">:</span> <span class="n">metric</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">metric</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">targets</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">)</span> <span class="c1"># Return a dict mapping metric names to current value</span> <span class="k">return</span> <span class="p">{</span><span class="n">m</span><span class="o">.</span><span class="n">name</span><span class="p">:</span> <span class="n">m</span><span class="o">.</span><span class="n">result</span><span class="p">()</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics</span><span class="p">}</span> <span class="n">model</span> <span class="o">=</span> <span class="n">get_model</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">SGD</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">),</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"sparse_categorical_crossentropy"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"sparse_categorical_accuracy"</span><span class="p">],</span> <span class="n">run_eagerly</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">step_counter</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># We pass epochs=1 and steps_per_epoch=10 to only run 10 steps of training.</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">batch_size</span><span 
class="o">=</span><span class="mi">1024</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">steps_per_epoch</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>----Start of step: 0 Max of dl_dw[0]: 0.0332 Min of dl_dw[0]: -0.0288 Mean of dl_dw[0]: 0.0003 - Max of d2l_dw2[0]: 5.2691 Min of d2l_dw2[0]: -2.6968 Mean of d2l_dw2[0]: 0.0981 </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>----Start of step: 1 Max of dl_dw[0]: 0.0445 Min of dl_dw[0]: -0.0169 Mean of dl_dw[0]: 0.0013 - Max of d2l_dw2[0]: 3.3575 Min of d2l_dw2[0]: -1.9024 Mean of d2l_dw2[0]: 0.0726 </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>----Start of step: 2 Max of dl_dw[0]: 0.0669 Min of dl_dw[0]: -0.0153 Mean of dl_dw[0]: 0.0013 - Max of d2l_dw2[0]: 5.0661 Min of d2l_dw2[0]: -1.7168 Mean of d2l_dw2[0]: 0.0809 </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>----Start of step: 3 Max of dl_dw[0]: 0.0545 Min of dl_dw[0]: -0.0125 Mean of dl_dw[0]: 0.0008 - Max of d2l_dw2[0]: 6.5223 Min of d2l_dw2[0]: -0.6604 Mean of d2l_dw2[0]: 0.0991 </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>----Start of step: 4 Max of dl_dw[0]: 0.0247 Min of dl_dw[0]: -0.0152 Mean of dl_dw[0]: -0.0001 - Max of d2l_dw2[0]: 2.8030 Min of d2l_dw2[0]: -0.1156 Mean of d2l_dw2[0]: 0.0321 </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>----Start of step: 5 Max of dl_dw[0]: 0.0051 Min of dl_dw[0]: -0.0096 Mean of dl_dw[0]: -0.0001 - Max of d2l_dw2[0]: 0.2545 Min of d2l_dw2[0]: -0.0284 Mean of d2l_dw2[0]: 0.0079 </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>----Start of step: 6 Max of dl_dw[0]: 0.0041 Min of dl_dw[0]: -0.0102 Mean of dl_dw[0]: -0.0001 - Max of d2l_dw2[0]: 0.2198 Min of d2l_dw2[0]: -0.0175 Mean of d2l_dw2[0]: 0.0069 </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>----Start of step: 7 Max of dl_dw[0]: 0.0035 Min of dl_dw[0]: -0.0086 Mean of dl_dw[0]: -0.0001 - Max of d2l_dw2[0]: 0.1485 Min of d2l_dw2[0]: -0.0175 Mean of d2l_dw2[0]: 0.0060 </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>----Start of step: 8 Max of dl_dw[0]: 0.0039 Min of dl_dw[0]: -0.0094 Mean of dl_dw[0]: -0.0001 - Max of d2l_dw2[0]: 0.1454 Min of d2l_dw2[0]: -0.0130 Mean of d2l_dw2[0]: 0.0061 </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>----Start of step: 9 Max of dl_dw[0]: 0.0028 Min of dl_dw[0]: -0.0087 Mean of dl_dw[0]: -0.0001 - Max of d2l_dw2[0]: 0.1491 Min of d2l_dw2[0]: -0.0326 Mean of d2l_dw2[0]: 0.0058 <keras.src.callbacks.history.History at 0x2a0d1e440> </code></pre></div> </div> <p>What did we learn?</p> <ul> <li>The first order and second order gradients can have values that differ by orders of magnitudes.</li> <li>Sometimes, they may not even have the same sign.</li> <li>Their values can vary greatly at each step.</li> </ul> <p>This leads us to an obvious idea: let's normalize the gradients before 
<div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">MyModel</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span> <span class="o">=</span> <span class="n">data</span> <span class="n">trainable_vars</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainable_variables</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape2</span><span class="p">:</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape1</span><span class="p">:</span> <span class="n">y_pred</span> <span class="o">=</span> <span class="bp">self</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="c1"># Forward pass</span> <span class="c1"># Compute the loss value</span> <span class="c1"># (the loss function is configured in `compile()`)</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">compute_loss</span><span class="p">(</span><span class="n">y</span><span class="o">=</span><span class="n">targets</span><span class="p">,</span> <span class="n">y_pred</span><span class="o">=</span><span class="n">y_pred</span><span class="p">)</span> <span class="c1"># Compute first-order gradients</span> <span class="n">dl_dw</span> <span class="o">=</span> <span class="n">tape1</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">)</span> <span class="c1"># Compute second-order gradients</span> <span class="n">d2l_dw2</span> <span class="o">=</span> <span class="n">tape2</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">dl_dw</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">)</span> <span class="n">dl_dw</span> <span class="o">=</span> <span class="p">[</span><span class="n">tf</span><span class="o">.</span><span class="n">math</span><span class="o">.</span><span class="n">l2_normalize</span><span class="p">(</span><span class="n">w</span><span class="p">)</span> <span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="n">dl_dw</span><span class="p">]</span> <span class="n">d2l_dw2</span> <span class="o">=</span> <span class="p">[</span><span class="n">tf</span><span class="o">.</span><span class="n">math</span><span class="o">.</span><span class="n">l2_normalize</span><span class="p">(</span><span class="n">w</span><span class="p">)</span> <span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="n">d2l_dw2</span><span class="p">]</span> <span class="c1"># Combine first-order and second-order
gradients</span> <span class="n">grads</span> <span class="o">=</span> <span class="p">[</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">w1</span> <span class="o">+</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="n">w2</span> <span class="k">for</span> <span class="p">(</span><span class="n">w1</span><span class="p">,</span> <span class="n">w2</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">d2l_dw2</span><span class="p">,</span> <span class="n">dl_dw</span><span class="p">)]</span> <span class="c1"># Update weights</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">grads</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">))</span> <span class="c1"># Update metrics (includes the metric that tracks the loss)</span> <span class="k">for</span> <span class="n">metric</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics</span><span class="p">:</span> <span class="k">if</span> <span class="n">metric</span><span class="o">.</span><span class="n">name</span> <span class="o">==</span> <span class="s2">"loss"</span><span class="p">:</span> <span class="n">metric</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">metric</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">targets</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">)</span> <span class="c1"># Return a dict mapping metric names to current value</span> <span class="k">return</span> <span class="p">{</span><span class="n">m</span><span class="o">.</span><span class="n">name</span><span class="p">:</span> <span class="n">m</span><span class="o">.</span><span class="n">result</span><span class="p">()</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics</span><span class="p">}</span> <span class="n">model</span> <span class="o">=</span> <span class="n">get_model</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">SGD</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">),</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"sparse_categorical_crossentropy"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"sparse_categorical_accuracy"</span><span class="p">],</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span 
class="mi">5</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">1024</span><span class="p">,</span> <span class="n">validation_split</span><span class="o">=</span><span class="mf">0.1</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/5 53/53 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - sparse_categorical_accuracy: 0.1250 - loss: 2.3185 - val_loss: 2.0502 - val_sparse_categorical_accuracy: 0.3373 Epoch 2/5 53/53 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step - sparse_categorical_accuracy: 0.3966 - loss: 1.9934 - val_loss: 1.8032 - val_sparse_categorical_accuracy: 0.5698 Epoch 3/5 53/53 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - sparse_categorical_accuracy: 0.5663 - loss: 1.7784 - val_loss: 1.6241 - val_sparse_categorical_accuracy: 0.6470 Epoch 4/5 53/53 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - sparse_categorical_accuracy: 0.6135 - loss: 1.6256 - val_loss: 1.5010 - val_sparse_categorical_accuracy: 0.6595 Epoch 5/5 53/53 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - sparse_categorical_accuracy: 0.6216 - loss: 1.5173 - val_loss: 1.4169 - val_sparse_categorical_accuracy: 0.6625 <keras.src.callbacks.history.History at 0x2a0d4c640> </code></pre></div> </div> <p>Now, training converges! It doesn't work well at all, but at least the model learns something.</p> <p>After spending a few minutes tuning parameters, we get to the following configuration that works somewhat well (achieves 97% validation accuracy and seems reasonably robust to overfitting):</p> <ul> <li>Use <code>0.2 * w1 + 0.8 * w2</code> for combining gradients.</li> <li>Use a learning rate that decays linearly over time.</li> </ul> <p>I'm not going to say that the idea works – this isn't at all how you're supposed to do second-order optimization (pointers: see the Newton & Gauss-Newton methods, quasi-Newton methods, and BFGS). But hopefully this demonstration gave you an idea of how you can debug your way out of uncomfortable training situations.</p> <p>Remember: use <code>run_eagerly=True</code> for debugging what happens in <code>fit()</code>. 
<p>Here's our final training run:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">MyModel</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span> <span class="o">=</span> <span class="n">data</span> <span class="n">trainable_vars</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainable_variables</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape2</span><span class="p">:</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape1</span><span class="p">:</span> <span class="n">y_pred</span> <span class="o">=</span> <span class="bp">self</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="c1"># Forward pass</span> <span class="c1"># Compute the loss value</span> <span class="c1"># (the loss function is configured in `compile()`)</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">compute_loss</span><span class="p">(</span><span class="n">y</span><span class="o">=</span><span class="n">targets</span><span class="p">,</span> <span class="n">y_pred</span><span class="o">=</span><span class="n">y_pred</span><span class="p">)</span> <span class="c1"># Compute first-order gradients</span> <span class="n">dl_dw</span> <span class="o">=</span> <span class="n">tape1</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">)</span> <span class="c1"># Compute second-order gradients</span> <span class="n">d2l_dw2</span> <span class="o">=</span> <span class="n">tape2</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">dl_dw</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">)</span> <span class="n">dl_dw</span> <span class="o">=</span> <span class="p">[</span><span class="n">tf</span><span class="o">.</span><span class="n">math</span><span class="o">.</span><span class="n">l2_normalize</span><span class="p">(</span><span class="n">w</span><span class="p">)</span> <span class="k">for</span> <span class="n">w</span> <span class="ow">in</span> <span class="n">dl_dw</span><span class="p">]</span> <span class="n">d2l_dw2</span> <span class="o">=</span> <span class="p">[</span><span class="n">tf</span><span class="o">.</span><span class="n">math</span><span class="o">.</span><span class="n">l2_normalize</span><span class="p">(</span><span class="n">w</span><span class="p">)</span> <span class="k">for</span> <span
class="n">w</span> <span class="ow">in</span> <span class="n">d2l_dw2</span><span class="p">]</span> <span class="c1"># Combine first-order and second-order gradients</span> <span class="n">grads</span> <span class="o">=</span> <span class="p">[</span><span class="mf">0.2</span> <span class="o">*</span> <span class="n">w1</span> <span class="o">+</span> <span class="mf">0.8</span> <span class="o">*</span> <span class="n">w2</span> <span class="k">for</span> <span class="p">(</span><span class="n">w1</span><span class="p">,</span> <span class="n">w2</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">d2l_dw2</span><span class="p">,</span> <span class="n">dl_dw</span><span class="p">)]</span> <span class="c1"># Update weights</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">grads</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">))</span> <span class="c1"># Update metrics (includes the metric that tracks the loss)</span> <span class="k">for</span> <span class="n">metric</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics</span><span class="p">:</span> <span class="k">if</span> <span class="n">metric</span><span class="o">.</span><span class="n">name</span> <span class="o">==</span> <span class="s2">"loss"</span><span class="p">:</span> <span class="n">metric</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">metric</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">targets</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">)</span> <span class="c1"># Return a dict mapping metric names to current value</span> <span class="k">return</span> <span class="p">{</span><span class="n">m</span><span class="o">.</span><span class="n">name</span><span class="p">:</span> <span class="n">m</span><span class="o">.</span><span class="n">result</span><span class="p">()</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics</span><span class="p">}</span> <span class="n">model</span> <span class="o">=</span> <span class="n">get_model</span><span class="p">()</span> <span class="n">lr</span> <span class="o">=</span> <span class="n">learning_rate</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">schedules</span><span class="o">.</span><span class="n">InverseTimeDecay</span><span class="p">(</span> <span class="n">initial_learning_rate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">decay_steps</span><span class="o">=</span><span class="mi">25</span><span class="p">,</span> <span class="n">decay_rate</span><span class="o">=</span><span class="mf">0.1</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span 
class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">SGD</span><span class="p">(</span><span class="n">lr</span><span class="p">),</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"sparse_categorical_crossentropy"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"sparse_categorical_accuracy"</span><span class="p">],</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">2048</span><span class="p">,</span> <span class="n">validation_split</span><span class="o">=</span><span class="mf">0.1</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 1s 14ms/step - sparse_categorical_accuracy: 0.5056 - loss: 1.7508 - val_loss: 0.6378 - val_sparse_categorical_accuracy: 0.8658 Epoch 2/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - sparse_categorical_accuracy: 0.8407 - loss: 0.6323 - val_loss: 0.4039 - val_sparse_categorical_accuracy: 0.8970 Epoch 3/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - sparse_categorical_accuracy: 0.8807 - loss: 0.4472 - val_loss: 0.3243 - val_sparse_categorical_accuracy: 0.9120 Epoch 4/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step - sparse_categorical_accuracy: 0.8947 - loss: 0.3781 - val_loss: 0.2861 - val_sparse_categorical_accuracy: 0.9235 Epoch 5/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - sparse_categorical_accuracy: 0.9022 - loss: 0.3453 - val_loss: 0.2622 - val_sparse_categorical_accuracy: 0.9288 Epoch 6/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - sparse_categorical_accuracy: 0.9093 - loss: 0.3243 - val_loss: 0.2523 - val_sparse_categorical_accuracy: 0.9303 Epoch 7/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - sparse_categorical_accuracy: 0.9148 - loss: 0.3021 - val_loss: 0.2362 - val_sparse_categorical_accuracy: 0.9338 Epoch 8/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - sparse_categorical_accuracy: 0.9184 - loss: 0.2899 - val_loss: 0.2289 - val_sparse_categorical_accuracy: 0.9365 Epoch 9/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - sparse_categorical_accuracy: 0.9212 - loss: 0.2784 - val_loss: 0.2183 - val_sparse_categorical_accuracy: 0.9383 Epoch 10/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - sparse_categorical_accuracy: 0.9246 - loss: 0.2670 - val_loss: 0.2097 - val_sparse_categorical_accuracy: 0.9405 Epoch 11/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - sparse_categorical_accuracy: 0.9267 - loss: 0.2563 - val_loss: 0.2063 - val_sparse_categorical_accuracy: 0.9442 Epoch 12/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9313 - loss: 0.2412 - val_loss: 0.1965 - val_sparse_categorical_accuracy: 0.9458 Epoch 13/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9324 - loss: 0.2411 - val_loss: 0.1917 - val_sparse_categorical_accuracy: 0.9472 Epoch 14/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9359 - loss: 0.2260 - val_loss: 0.1861 - val_sparse_categorical_accuracy: 0.9495 Epoch 15/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - 
sparse_categorical_accuracy: 0.9374 - loss: 0.2234 - val_loss: 0.1804 - val_sparse_categorical_accuracy: 0.9517 Epoch 16/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 14ms/step - sparse_categorical_accuracy: 0.9382 - loss: 0.2196 - val_loss: 0.1761 - val_sparse_categorical_accuracy: 0.9528 Epoch 17/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 14ms/step - sparse_categorical_accuracy: 0.9417 - loss: 0.2076 - val_loss: 0.1709 - val_sparse_categorical_accuracy: 0.9557 Epoch 18/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - sparse_categorical_accuracy: 0.9423 - loss: 0.2032 - val_loss: 0.1664 - val_sparse_categorical_accuracy: 0.9555 Epoch 19/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9444 - loss: 0.1953 - val_loss: 0.1616 - val_sparse_categorical_accuracy: 0.9582 Epoch 20/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9451 - loss: 0.1916 - val_loss: 0.1597 - val_sparse_categorical_accuracy: 0.9592 Epoch 21/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - sparse_categorical_accuracy: 0.9473 - loss: 0.1866 - val_loss: 0.1563 - val_sparse_categorical_accuracy: 0.9615 Epoch 22/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9486 - loss: 0.1818 - val_loss: 0.1520 - val_sparse_categorical_accuracy: 0.9617 Epoch 23/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9502 - loss: 0.1794 - val_loss: 0.1499 - val_sparse_categorical_accuracy: 0.9635 Epoch 24/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9502 - loss: 0.1759 - val_loss: 0.1466 - val_sparse_categorical_accuracy: 0.9640 Epoch 25/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9515 - loss: 0.1714 - val_loss: 0.1437 - val_sparse_categorical_accuracy: 0.9645 Epoch 26/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 14ms/step - sparse_categorical_accuracy: 0.9535 - loss: 0.1649 - val_loss: 0.1435 - val_sparse_categorical_accuracy: 0.9640 Epoch 27/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 13ms/step - sparse_categorical_accuracy: 0.9548 - loss: 0.1628 - val_loss: 0.1411 - val_sparse_categorical_accuracy: 0.9650 Epoch 28/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9541 - loss: 0.1620 - val_loss: 0.1384 - val_sparse_categorical_accuracy: 0.9655 Epoch 29/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9564 - loss: 0.1560 - val_loss: 0.1359 - val_sparse_categorical_accuracy: 0.9668 Epoch 30/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9577 - loss: 0.1547 - val_loss: 0.1338 - val_sparse_categorical_accuracy: 0.9672 Epoch 31/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9569 - loss: 0.1520 - val_loss: 0.1329 - val_sparse_categorical_accuracy: 0.9663 Epoch 32/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9582 - loss: 0.1478 - val_loss: 0.1320 - val_sparse_categorical_accuracy: 0.9675 Epoch 33/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9582 - loss: 0.1483 - val_loss: 0.1292 - val_sparse_categorical_accuracy: 0.9670 Epoch 34/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9594 - loss: 0.1448 - val_loss: 0.1274 - val_sparse_categorical_accuracy: 0.9677 Epoch 35/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9587 - loss: 0.1452 - val_loss: 0.1262 - val_sparse_categorical_accuracy: 0.9678 Epoch 36/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9603 - loss: 0.1418 - val_loss: 
0.1251 - val_sparse_categorical_accuracy: 0.9677 Epoch 37/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9603 - loss: 0.1402 - val_loss: 0.1238 - val_sparse_categorical_accuracy: 0.9682 Epoch 38/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - sparse_categorical_accuracy: 0.9618 - loss: 0.1382 - val_loss: 0.1228 - val_sparse_categorical_accuracy: 0.9680 Epoch 39/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9630 - loss: 0.1335 - val_loss: 0.1213 - val_sparse_categorical_accuracy: 0.9695 Epoch 40/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9629 - loss: 0.1327 - val_loss: 0.1198 - val_sparse_categorical_accuracy: 0.9698 Epoch 41/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9639 - loss: 0.1323 - val_loss: 0.1191 - val_sparse_categorical_accuracy: 0.9695 Epoch 42/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9629 - loss: 0.1346 - val_loss: 0.1183 - val_sparse_categorical_accuracy: 0.9692 Epoch 43/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9661 - loss: 0.1262 - val_loss: 0.1182 - val_sparse_categorical_accuracy: 0.9700 Epoch 44/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9652 - loss: 0.1274 - val_loss: 0.1163 - val_sparse_categorical_accuracy: 0.9702 Epoch 45/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9650 - loss: 0.1259 - val_loss: 0.1154 - val_sparse_categorical_accuracy: 0.9708 Epoch 46/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step - sparse_categorical_accuracy: 0.9647 - loss: 0.1246 - val_loss: 0.1148 - val_sparse_categorical_accuracy: 0.9703 Epoch 47/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9659 - loss: 0.1236 - val_loss: 0.1137 - val_sparse_categorical_accuracy: 0.9707 Epoch 48/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9665 - loss: 0.1221 - val_loss: 0.1133 - val_sparse_categorical_accuracy: 0.9710 Epoch 49/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9675 - loss: 0.1192 - val_loss: 0.1124 - val_sparse_categorical_accuracy: 0.9712 Epoch 50/50 27/27 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step - sparse_categorical_accuracy: 0.9664 - loss: 0.1214 - val_loss: 0.1112 - val_sparse_categorical_accuracy: 0.9707 <keras.src.callbacks.history.History at 0x29e76ae60> </code></pre></div> </div> </div> </div> </div> </div> </body> </html>