CINXE.COM
Writing your own callbacks
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/guides/writing_your_own_callbacks/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Writing your own callbacks"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Writing your own callbacks"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Writing your own callbacks</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link active" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-sublink" href="/guides/functional_api/">The Functional API</a> <a class="nav-sublink" href="/guides/sequential_model/">The Sequential model</a> <a class="nav-sublink" href="/guides/making_new_layers_and_models_via_subclassing/">Making new layers & models via subclassing</a> <a class="nav-sublink" href="/guides/training_with_built_in_methods/">Training & evaluation with the built-in methods</a> <a class="nav-sublink" href="/guides/custom_train_step_in_jax/">Customizing `fit()` with JAX</a> <a class="nav-sublink" href="/guides/custom_train_step_in_tensorflow/">Customizing `fit()` with TensorFlow</a> <a class="nav-sublink" href="/guides/custom_train_step_in_torch/">Customizing `fit()` with PyTorch</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_jax/">Writing a custom training loop in JAX</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_tensorflow/">Writing a custom training loop in TensorFlow</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_torch/">Writing a custom training loop in PyTorch</a> <a class="nav-sublink" href="/guides/serialization_and_saving/">Serialization & saving</a> <a class="nav-sublink" href="/guides/customizing_saving_and_serialization/">Customizing saving & serialization</a> <a class="nav-sublink active" href="/guides/writing_your_own_callbacks/">Writing your own callbacks</a> <a class="nav-sublink" href="/guides/transfer_learning/">Transfer learning & fine-tuning</a> <a class="nav-sublink" href="/guides/distributed_training_with_jax/">Distributed training with JAX</a> <a class="nav-sublink" href="/guides/distributed_training_with_tensorflow/">Distributed training with TensorFlow</a> <a class="nav-sublink" href="/guides/distributed_training_with_torch/">Distributed training with PyTorch</a> <a class="nav-sublink" href="/guides/distribution/">Distributed training with Keras 3</a> <a class="nav-sublink" href="/guides/migrating_to_keras_3/">Migrating Keras 2 code to Keras 3</a> <a class="nav-sublink" href="/guides/keras_tuner/">Hyperparameter Tuning</a> <a class="nav-sublink" href="/guides/keras_cv/">KerasCV</a> <a class="nav-sublink" href="/guides/keras_nlp/">KerasNLP</a> <a class="nav-sublink" href="/guides/keras_hub/">KerasHub</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/guides/'>Developer guides</a> / Writing your own callbacks </div> <div class='k-content'> <h1 id="writing-your-own-callbacks">Writing your own callbacks</h1> <p><strong>Authors:</strong> Rick Chao, Francois Chollet<br> <strong>Date created:</strong> 2019/03/20<br> <strong>Last modified:</strong> 2023/06/25<br> <strong>Description:</strong> Complete guide to writing new Keras callbacks.</p> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/writing_your_own_callbacks.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/guides/writing_your_own_callbacks.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>A callback is a powerful tool to customize the behavior of a Keras model during training, evaluation, or inference. Examples include <a href="/api/callbacks/tensorboard#tensorboard-class"><code>keras.callbacks.TensorBoard</code></a> to visualize training progress and results with TensorBoard, or <a href="/api/callbacks/model_checkpoint#modelcheckpoint-class"><code>keras.callbacks.ModelCheckpoint</code></a> to periodically save your model during training.</p> <p>In this guide, you will learn what a Keras callback is, what it can do, and how you can build your own. We provide a few demos of simple callback applications to get you started.</p> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">keras</span> </code></pre></div> <hr /> <h2 id="keras-callbacks-overview">Keras callbacks overview</h2> <p>All callbacks subclass the <a href="/api/callbacks/base_callback#callback-class"><code>keras.callbacks.Callback</code></a> class, and override a set of methods called at various stages of training, testing, and predicting. Callbacks are useful to get a view on internal states and statistics of the model during training.</p> <p>You can pass a list of callbacks (as the keyword argument <code>callbacks</code>) to the following model methods:</p> <ul> <li><code>keras.Model.fit()</code></li> <li><code>keras.Model.evaluate()</code></li> <li><code>keras.Model.predict()</code></li> </ul> <hr /> <h2 id="an-overview-of-callback-methods">An overview of callback methods</h2> <h3 id="global-methods">Global methods</h3> <h4 id="ontraintestpredictbeginself-logsnone"><code>on_(train|test|predict)_begin(self, logs=None)</code></h4> <p>Called at the beginning of <code>fit</code>/<code>evaluate</code>/<code>predict</code>.</p> <h4 id="ontraintestpredictendself-logsnone"><code>on_(train|test|predict)_end(self, logs=None)</code></h4> <p>Called at the end of <code>fit</code>/<code>evaluate</code>/<code>predict</code>.</p> <h3 id="batchlevel-methods-for-trainingtestingpredicting">Batch-level methods for training/testing/predicting</h3> <h4 id="ontraintestpredictbatchbeginself-batch-logsnone"><code>on_(train|test|predict)_batch_begin(self, batch, logs=None)</code></h4> <p>Called right before processing a batch during training/testing/predicting.</p> <h4 id="ontraintestpredictbatchendself-batch-logsnone"><code>on_(train|test|predict)_batch_end(self, batch, logs=None)</code></h4> <p>Called at the end of training/testing/predicting a batch. Within this method, <code>logs</code> is a dict containing the metrics results.</p> <h3 id="epochlevel-methods-training-only">Epoch-level methods (training only)</h3> <h4 id="onepochbeginself-epoch-logsnone"><code>on_epoch_begin(self, epoch, logs=None)</code></h4> <p>Called at the beginning of an epoch during training.</p> <h4 id="onepochendself-epoch-logsnone"><code>on_epoch_end(self, epoch, logs=None)</code></h4> <p>Called at the end of an epoch during training.</p> <hr /> <h2 id="a-basic-example">A basic example</h2> <p>Let's take a look at a concrete example. To get started, let's import tensorflow and define a simple Sequential Keras model:</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Define the Keras model to add callbacks to</span> <span class="k">def</span> <span class="nf">get_model</span><span class="p">():</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">))</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">RMSprop</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">),</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"mean_squared_error"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"mean_absolute_error"</span><span class="p">],</span> <span class="p">)</span> <span class="k">return</span> <span class="n">model</span> </code></pre></div> <p>Then, load the MNIST data for training and testing from Keras datasets API:</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Load example MNIST data and pre-process it</span> <span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">mnist</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span> <span class="n">x_train</span> <span class="o">=</span> <span class="n">x_train</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">784</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">)</span> <span class="o">/</span> <span class="mf">255.0</span> <span class="n">x_test</span> <span class="o">=</span> <span class="n">x_test</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">784</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">)</span> <span class="o">/</span> <span class="mf">255.0</span> <span class="c1"># Limit the data to 1000 samples</span> <span class="n">x_train</span> <span class="o">=</span> <span class="n">x_train</span><span class="p">[:</span><span class="mi">1000</span><span class="p">]</span> <span class="n">y_train</span> <span class="o">=</span> <span class="n">y_train</span><span class="p">[:</span><span class="mi">1000</span><span class="p">]</span> <span class="n">x_test</span> <span class="o">=</span> <span class="n">x_test</span><span class="p">[:</span><span class="mi">1000</span><span class="p">]</span> <span class="n">y_test</span> <span class="o">=</span> <span class="n">y_test</span><span class="p">[:</span><span class="mi">1000</span><span class="p">]</span> </code></pre></div> <p>Now, define a simple custom callback that logs:</p> <ul> <li>When <code>fit</code>/<code>evaluate</code>/<code>predict</code> starts & ends</li> <li>When each epoch starts & ends</li> <li>When each training batch starts & ends</li> <li>When each evaluation (test) batch starts & ends</li> <li>When each inference (prediction) batch starts & ends</li> </ul> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">CustomCallback</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">Callback</span><span class="p">):</span> <span class="k">def</span> <span class="nf">on_train_begin</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">keys</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">logs</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Starting training; got log keys: </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">keys</span><span class="p">))</span> <span class="k">def</span> <span class="nf">on_train_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">keys</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">logs</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Stop training; got log keys: </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">keys</span><span class="p">))</span> <span class="k">def</span> <span class="nf">on_epoch_begin</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">keys</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">logs</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Start epoch </span><span class="si">{}</span><span class="s2"> of training; got log keys: </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">epoch</span><span class="p">,</span> <span class="n">keys</span><span class="p">))</span> <span class="k">def</span> <span class="nf">on_epoch_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">keys</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">logs</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"End epoch </span><span class="si">{}</span><span class="s2"> of training; got log keys: </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">epoch</span><span class="p">,</span> <span class="n">keys</span><span class="p">))</span> <span class="k">def</span> <span class="nf">on_test_begin</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">keys</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">logs</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Start testing; got log keys: </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">keys</span><span class="p">))</span> <span class="k">def</span> <span class="nf">on_test_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">keys</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">logs</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Stop testing; got log keys: </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">keys</span><span class="p">))</span> <span class="k">def</span> <span class="nf">on_predict_begin</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">keys</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">logs</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Start predicting; got log keys: </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">keys</span><span class="p">))</span> <span class="k">def</span> <span class="nf">on_predict_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">keys</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">logs</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Stop predicting; got log keys: </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">keys</span><span class="p">))</span> <span class="k">def</span> <span class="nf">on_train_batch_begin</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">keys</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">logs</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"...Training: start of batch </span><span class="si">{}</span><span class="s2">; got log keys: </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">keys</span><span class="p">))</span> <span class="k">def</span> <span class="nf">on_train_batch_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">keys</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">logs</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"...Training: end of batch </span><span class="si">{}</span><span class="s2">; got log keys: </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">keys</span><span class="p">))</span> <span class="k">def</span> <span class="nf">on_test_batch_begin</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">keys</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">logs</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"...Evaluating: start of batch </span><span class="si">{}</span><span class="s2">; got log keys: </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">keys</span><span class="p">))</span> <span class="k">def</span> <span class="nf">on_test_batch_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">keys</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">logs</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"...Evaluating: end of batch </span><span class="si">{}</span><span class="s2">; got log keys: </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">keys</span><span class="p">))</span> <span class="k">def</span> <span class="nf">on_predict_batch_begin</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">keys</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">logs</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"...Predicting: start of batch </span><span class="si">{}</span><span class="s2">; got log keys: </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">keys</span><span class="p">))</span> <span class="k">def</span> <span class="nf">on_predict_batch_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">keys</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">logs</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"...Predicting: end of batch </span><span class="si">{}</span><span class="s2">; got log keys: </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">keys</span><span class="p">))</span> </code></pre></div> <p>Let's try it out:</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="n">get_model</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">validation_split</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">CustomCallback</span><span class="p">()],</span> <span class="p">)</span> <span class="n">res</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span> <span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">CustomCallback</span><span class="p">()]</span> <span class="p">)</span> <span class="n">res</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">CustomCallback</span><span class="p">()])</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Starting training; got log keys: [] Start epoch 0 of training; got log keys: [] ...Training: start of batch 0; got log keys: [] ...Training: end of batch 0; got log keys: ['loss', 'mean_absolute_error'] ...Training: start of batch 1; got log keys: [] ...Training: end of batch 1; got log keys: ['loss', 'mean_absolute_error'] ...Training: start of batch 2; got log keys: [] ...Training: end of batch 2; got log keys: ['loss', 'mean_absolute_error'] ...Training: start of batch 3; got log keys: [] ...Training: end of batch 3; got log keys: ['loss', 'mean_absolute_error'] Start testing; got log keys: [] ...Evaluating: start of batch 0; got log keys: [] ...Evaluating: end of batch 0; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 1; got log keys: [] ...Evaluating: end of batch 1; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 2; got log keys: [] ...Evaluating: end of batch 2; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 3; got log keys: [] ...Evaluating: end of batch 3; got log keys: ['loss', 'mean_absolute_error'] Stop testing; got log keys: ['loss', 'mean_absolute_error'] End epoch 0 of training; got log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error'] Stop training; got log keys: ['loss', 'mean_absolute_error', 'val_loss', 'val_mean_absolute_error'] Start testing; got log keys: [] ...Evaluating: start of batch 0; got log keys: [] ...Evaluating: end of batch 0; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 1; got log keys: [] ...Evaluating: end of batch 1; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 2; got log keys: [] ...Evaluating: end of batch 2; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 3; got log keys: [] ...Evaluating: end of batch 3; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 4; got log keys: [] ...Evaluating: end of batch 4; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 5; got log keys: [] ...Evaluating: end of batch 5; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 6; got log keys: [] ...Evaluating: end of batch 6; got log keys: ['loss', 'mean_absolute_error'] ...Evaluating: start of batch 7; got log keys: [] ...Evaluating: end of batch 7; got log keys: ['loss', 'mean_absolute_error'] Stop testing; got log keys: ['loss', 'mean_absolute_error'] Start predicting; got log keys: [] ...Predicting: start of batch 0; got log keys: [] ...Predicting: end of batch 0; got log keys: ['outputs'] 1/8 ━━[37m━━━━━━━━━━━━━━━━━━ 0s 13ms/step...Predicting: start of batch 1; got log keys: [] ...Predicting: end of batch 1; got log keys: ['outputs'] ...Predicting: start of batch 2; got log keys: [] ...Predicting: end of batch 2; got log keys: ['outputs'] ...Predicting: start of batch 3; got log keys: [] ...Predicting: end of batch 3; got log keys: ['outputs'] ...Predicting: start of batch 4; got log keys: [] ...Predicting: end of batch 4; got log keys: ['outputs'] ...Predicting: start of batch 5; got log keys: [] ...Predicting: end of batch 5; got log keys: ['outputs'] ...Predicting: start of batch 6; got log keys: [] ...Predicting: end of batch 6; got log keys: ['outputs'] ...Predicting: start of batch 7; got log keys: [] ...Predicting: end of batch 7; got log keys: ['outputs'] Stop predicting; got log keys: [] 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step </code></pre></div> </div> <h3 id="usage-of-logs-dict">Usage of <code>logs</code> dict</h3> <p>The <code>logs</code> dict contains the loss value, and all the metrics at the end of a batch or epoch. Example includes the loss and mean absolute error.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">LossAndErrorPrintingCallback</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">Callback</span><span class="p">):</span> <span class="k">def</span> <span class="nf">on_train_batch_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="nb">print</span><span class="p">(</span> <span class="s2">"Up to batch </span><span class="si">{}</span><span class="s2">, the average loss is </span><span class="si">{:7.2f}</span><span class="s2">."</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">logs</span><span class="p">[</span><span class="s2">"loss"</span><span class="p">])</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">on_test_batch_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="nb">print</span><span class="p">(</span> <span class="s2">"Up to batch </span><span class="si">{}</span><span class="s2">, the average loss is </span><span class="si">{:7.2f}</span><span class="s2">."</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">batch</span><span class="p">,</span> <span class="n">logs</span><span class="p">[</span><span class="s2">"loss"</span><span class="p">])</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">on_epoch_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="nb">print</span><span class="p">(</span> <span class="s2">"The average loss for epoch </span><span class="si">{}</span><span class="s2"> is </span><span class="si">{:7.2f}</span><span class="s2"> "</span> <span class="s2">"and mean absolute error is </span><span class="si">{:7.2f}</span><span class="s2">."</span><span class="o">.</span><span class="n">format</span><span class="p">(</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">logs</span><span class="p">[</span><span class="s2">"loss"</span><span class="p">],</span> <span class="n">logs</span><span class="p">[</span><span class="s2">"mean_absolute_error"</span><span class="p">]</span> <span class="p">)</span> <span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">get_model</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">LossAndErrorPrintingCallback</span><span class="p">()],</span> <span class="p">)</span> <span class="n">res</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span> <span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">LossAndErrorPrintingCallback</span><span class="p">()],</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Up to batch 0, the average loss is 29.25. Up to batch 1, the average loss is 485.36. Up to batch 2, the average loss is 330.94. Up to batch 3, the average loss is 250.62. Up to batch 4, the average loss is 202.20. Up to batch 5, the average loss is 169.51. Up to batch 6, the average loss is 145.98. Up to batch 7, the average loss is 128.48. The average loss for epoch 0 is 128.48 and mean absolute error is 6.01. Up to batch 0, the average loss is 5.10. Up to batch 1, the average loss is 4.80. Up to batch 2, the average loss is 4.96. Up to batch 3, the average loss is 4.96. Up to batch 4, the average loss is 4.82. Up to batch 5, the average loss is 4.69. Up to batch 6, the average loss is 4.51. Up to batch 7, the average loss is 4.53. The average loss for epoch 1 is 4.53 and mean absolute error is 1.72. Up to batch 0, the average loss is 5.08. Up to batch 1, the average loss is 4.66. Up to batch 2, the average loss is 4.64. Up to batch 3, the average loss is 4.72. Up to batch 4, the average loss is 4.82. Up to batch 5, the average loss is 4.83. Up to batch 6, the average loss is 4.77. Up to batch 7, the average loss is 4.72. </code></pre></div> </div> <hr /> <h2 id="usage-of-selfmodel-attribute">Usage of <code>self.model</code> attribute</h2> <p>In addition to receiving log information when one of their methods is called, callbacks have access to the model associated with the current round of training/evaluation/inference: <code>self.model</code>.</p> <p>Here are a few of the things you can do with <code>self.model</code> in a callback:</p> <ul> <li>Set <code>self.model.stop_training = True</code> to immediately interrupt training.</li> <li>Mutate hyperparameters of the optimizer (available as <code>self.model.optimizer</code>), such as <code>self.model.optimizer.learning_rate</code>.</li> <li>Save the model at period intervals.</li> <li>Record the output of <code>model.predict()</code> on a few test samples at the end of each epoch, to use as a sanity check during training.</li> <li>Extract visualizations of intermediate features at the end of each epoch, to monitor what the model is learning over time.</li> <li>etc.</li> </ul> <p>Let's see this in action in a couple of examples.</p> <hr /> <h2 id="examples-of-keras-callback-applications">Examples of Keras callback applications</h2> <h3 id="early-stopping-at-minimum-loss">Early stopping at minimum loss</h3> <p>This first example shows the creation of a <code>Callback</code> that stops training when the minimum of loss has been reached, by setting the attribute <code>self.model.stop_training</code> (boolean). Optionally, you can provide an argument <code>patience</code> to specify how many epochs we should wait before stopping after having reached a local minimum.</p> <p><a href="/api/callbacks/early_stopping#earlystopping-class"><code>keras.callbacks.EarlyStopping</code></a> provides a more complete and general implementation.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">EarlyStoppingAtMinLoss</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">Callback</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Stop training when the loss is at its min, i.e. the loss stops decreasing.</span> <span class="sd"> Arguments:</span> <span class="sd"> patience: Number of epochs to wait after min has been hit. After this</span> <span class="sd"> number of no improvement, training stops.</span> <span class="sd"> """</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">patience</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">patience</span> <span class="o">=</span> <span class="n">patience</span> <span class="c1"># best_weights to store the weights at which the minimum loss occurs.</span> <span class="bp">self</span><span class="o">.</span><span class="n">best_weights</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">def</span> <span class="nf">on_train_begin</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="c1"># The number of epoch it has waited when loss is no longer minimum.</span> <span class="bp">self</span><span class="o">.</span><span class="n">wait</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># The epoch the training stops at.</span> <span class="bp">self</span><span class="o">.</span><span class="n">stopped_epoch</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># Initialize the best as infinity.</span> <span class="bp">self</span><span class="o">.</span><span class="n">best</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">inf</span> <span class="k">def</span> <span class="nf">on_epoch_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">current</span> <span class="o">=</span> <span class="n">logs</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="s2">"loss"</span><span class="p">)</span> <span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">less</span><span class="p">(</span><span class="n">current</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">best</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">best</span> <span class="o">=</span> <span class="n">current</span> <span class="bp">self</span><span class="o">.</span><span class="n">wait</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># Record the best weights if current results is better (less).</span> <span class="bp">self</span><span class="o">.</span><span class="n">best_weights</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">get_weights</span><span class="p">()</span> <span class="k">else</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">wait</span> <span class="o">+=</span> <span class="mi">1</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">wait</span> <span class="o">>=</span> <span class="bp">self</span><span class="o">.</span><span class="n">patience</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">stopped_epoch</span> <span class="o">=</span> <span class="n">epoch</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">stop_training</span> <span class="o">=</span> <span class="kc">True</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Restoring model weights from the end of the best epoch."</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">set_weights</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">best_weights</span><span class="p">)</span> <span class="k">def</span> <span class="nf">on_train_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">stopped_epoch</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Epoch </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">stopped_epoch</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">1</span><span class="si">}</span><span class="s2">: early stopping"</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">get_model</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">30</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">LossAndErrorPrintingCallback</span><span class="p">(),</span> <span class="n">EarlyStoppingAtMinLoss</span><span class="p">()],</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Up to batch 0, the average loss is 25.57. Up to batch 1, the average loss is 471.66. Up to batch 2, the average loss is 322.55. Up to batch 3, the average loss is 243.88. Up to batch 4, the average loss is 196.53. Up to batch 5, the average loss is 165.02. Up to batch 6, the average loss is 142.34. Up to batch 7, the average loss is 125.17. Up to batch 8, the average loss is 111.83. Up to batch 9, the average loss is 101.35. Up to batch 10, the average loss is 92.60. Up to batch 11, the average loss is 85.16. Up to batch 12, the average loss is 79.02. Up to batch 13, the average loss is 73.71. Up to batch 14, the average loss is 69.23. Up to batch 15, the average loss is 65.26. The average loss for epoch 0 is 65.26 and mean absolute error is 3.89. Up to batch 0, the average loss is 3.92. Up to batch 1, the average loss is 4.34. Up to batch 2, the average loss is 5.39. Up to batch 3, the average loss is 6.58. Up to batch 4, the average loss is 10.55. Up to batch 5, the average loss is 19.29. Up to batch 6, the average loss is 31.58. Up to batch 7, the average loss is 38.20. Up to batch 8, the average loss is 41.96. Up to batch 9, the average loss is 41.30. Up to batch 10, the average loss is 39.31. Up to batch 11, the average loss is 37.09. Up to batch 12, the average loss is 35.08. Up to batch 13, the average loss is 33.27. Up to batch 14, the average loss is 31.54. Up to batch 15, the average loss is 30.00. The average loss for epoch 1 is 30.00 and mean absolute error is 4.23. Up to batch 0, the average loss is 5.70. Up to batch 1, the average loss is 6.90. Up to batch 2, the average loss is 7.74. Up to batch 3, the average loss is 8.85. Up to batch 4, the average loss is 12.53. Up to batch 5, the average loss is 21.55. Up to batch 6, the average loss is 35.70. Up to batch 7, the average loss is 44.16. Up to batch 8, the average loss is 44.82. Up to batch 9, the average loss is 43.07. Up to batch 10, the average loss is 40.51. Up to batch 11, the average loss is 38.44. Up to batch 12, the average loss is 36.69. Up to batch 13, the average loss is 34.77. Up to batch 14, the average loss is 32.97. Up to batch 15, the average loss is 31.32. The average loss for epoch 2 is 31.32 and mean absolute error is 4.39. Restoring model weights from the end of the best epoch. Epoch 3: early stopping <keras.src.callbacks.history.History at 0x1187b7430> </code></pre></div> </div> <h3 id="learning-rate-scheduling">Learning rate scheduling</h3> <p>In this example, we show how a custom Callback can be used to dynamically change the learning rate of the optimizer during the course of training.</p> <p>See <code>callbacks.LearningRateScheduler</code> for a more general implementations.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">CustomLearningRateScheduler</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">Callback</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Learning rate scheduler which sets the learning rate according to schedule.</span> <span class="sd"> Arguments:</span> <span class="sd"> schedule: a function that takes an epoch index</span> <span class="sd"> (integer, indexed from 0) and current learning rate</span> <span class="sd"> as inputs and returns a new learning rate as output (float).</span> <span class="sd"> """</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">schedule</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">schedule</span> <span class="o">=</span> <span class="n">schedule</span> <span class="k">def</span> <span class="nf">on_epoch_begin</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="k">if</span> <span class="ow">not</span> <span class="nb">hasattr</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">optimizer</span><span class="p">,</span> <span class="s2">"learning_rate"</span><span class="p">):</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s1">'Optimizer must have a "learning_rate" attribute.'</span><span class="p">)</span> <span class="c1"># Get the current learning rate from model's optimizer.</span> <span class="n">lr</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">learning_rate</span> <span class="c1"># Call schedule function to get the scheduled learning rate.</span> <span class="n">scheduled_lr</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">schedule</span><span class="p">(</span><span class="n">epoch</span><span class="p">,</span> <span class="n">lr</span><span class="p">)</span> <span class="c1"># Set the value back to the optimizer before this epoch starts</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">learning_rate</span> <span class="o">=</span> <span class="n">scheduled_lr</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="se">\n</span><span class="s2">Epoch </span><span class="si">{</span><span class="n">epoch</span><span class="si">}</span><span class="s2">: Learning rate is </span><span class="si">{</span><span class="nb">float</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">scheduled_lr</span><span class="p">))</span><span class="si">}</span><span class="s2">."</span><span class="p">)</span> <span class="n">LR_SCHEDULE</span> <span class="o">=</span> <span class="p">[</span> <span class="c1"># (epoch to start, learning rate) tuples</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mf">0.05</span><span class="p">),</span> <span class="p">(</span><span class="mi">6</span><span class="p">,</span> <span class="mf">0.01</span><span class="p">),</span> <span class="p">(</span><span class="mi">9</span><span class="p">,</span> <span class="mf">0.005</span><span class="p">),</span> <span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mf">0.001</span><span class="p">),</span> <span class="p">]</span> <span class="k">def</span> <span class="nf">lr_schedule</span><span class="p">(</span><span class="n">epoch</span><span class="p">,</span> <span class="n">lr</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Helper function to retrieve the scheduled learning rate based on epoch."""</span> <span class="k">if</span> <span class="n">epoch</span> <span class="o"><</span> <span class="n">LR_SCHEDULE</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="ow">or</span> <span class="n">epoch</span> <span class="o">></span> <span class="n">LR_SCHEDULE</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">][</span><span class="mi">0</span><span class="p">]:</span> <span class="k">return</span> <span class="n">lr</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">LR_SCHEDULE</span><span class="p">)):</span> <span class="k">if</span> <span class="n">epoch</span> <span class="o">==</span> <span class="n">LR_SCHEDULE</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]:</span> <span class="k">return</span> <span class="n">LR_SCHEDULE</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span> <span class="k">return</span> <span class="n">lr</span> <span class="n">model</span> <span class="o">=</span> <span class="n">get_model</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">15</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span> <span class="n">LossAndErrorPrintingCallback</span><span class="p">(),</span> <span class="n">CustomLearningRateScheduler</span><span class="p">(</span><span class="n">lr_schedule</span><span class="p">),</span> <span class="p">],</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 0: Learning rate is 0.10000000149011612. Up to batch 0, the average loss is 27.90. Up to batch 1, the average loss is 439.49. Up to batch 2, the average loss is 302.08. Up to batch 3, the average loss is 228.83. Up to batch 4, the average loss is 184.97. Up to batch 5, the average loss is 155.25. Up to batch 6, the average loss is 134.03. Up to batch 7, the average loss is 118.29. Up to batch 8, the average loss is 105.65. Up to batch 9, the average loss is 95.53. Up to batch 10, the average loss is 87.25. Up to batch 11, the average loss is 80.33. Up to batch 12, the average loss is 74.48. Up to batch 13, the average loss is 69.46. Up to batch 14, the average loss is 65.05. Up to batch 15, the average loss is 61.31. The average loss for epoch 0 is 61.31 and mean absolute error is 3.85. </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1: Learning rate is 0.10000000149011612. Up to batch 0, the average loss is 57.96. Up to batch 1, the average loss is 55.11. Up to batch 2, the average loss is 52.81. Up to batch 3, the average loss is 51.06. Up to batch 4, the average loss is 50.58. Up to batch 5, the average loss is 51.49. Up to batch 6, the average loss is 53.24. Up to batch 7, the average loss is 54.20. Up to batch 8, the average loss is 54.39. Up to batch 9, the average loss is 54.31. Up to batch 10, the average loss is 53.83. Up to batch 11, the average loss is 52.93. Up to batch 12, the average loss is 51.73. Up to batch 13, the average loss is 50.34. Up to batch 14, the average loss is 48.94. Up to batch 15, the average loss is 47.65. The average loss for epoch 1 is 47.65 and mean absolute error is 4.30. </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 2: Learning rate is 0.10000000149011612. Up to batch 0, the average loss is 46.38. Up to batch 1, the average loss is 45.16. Up to batch 2, the average loss is 44.03. Up to batch 3, the average loss is 43.11. Up to batch 4, the average loss is 42.52. Up to batch 5, the average loss is 42.32. Up to batch 6, the average loss is 43.06. Up to batch 7, the average loss is 44.58. Up to batch 8, the average loss is 45.33. Up to batch 9, the average loss is 45.15. Up to batch 10, the average loss is 44.59. Up to batch 11, the average loss is 43.88. Up to batch 12, the average loss is 43.17. Up to batch 13, the average loss is 42.40. Up to batch 14, the average loss is 41.74. Up to batch 15, the average loss is 41.19. The average loss for epoch 2 is 41.19 and mean absolute error is 4.27. </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 3: Learning rate is 0.05. Up to batch 0, the average loss is 40.85. Up to batch 1, the average loss is 40.11. Up to batch 2, the average loss is 39.38. Up to batch 3, the average loss is 38.69. Up to batch 4, the average loss is 38.01. Up to batch 5, the average loss is 37.38. Up to batch 6, the average loss is 36.77. Up to batch 7, the average loss is 36.18. Up to batch 8, the average loss is 35.61. Up to batch 9, the average loss is 35.08. Up to batch 10, the average loss is 34.54. Up to batch 11, the average loss is 34.04. Up to batch 12, the average loss is 33.56. Up to batch 13, the average loss is 33.08. Up to batch 14, the average loss is 32.64. Up to batch 15, the average loss is 32.25. The average loss for epoch 3 is 32.25 and mean absolute error is 3.64. </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 4: Learning rate is 0.05000000074505806. Up to batch 0, the average loss is 31.83. Up to batch 1, the average loss is 31.42. Up to batch 2, the average loss is 31.05. Up to batch 3, the average loss is 30.72. Up to batch 4, the average loss is 30.49. Up to batch 5, the average loss is 30.37. Up to batch 6, the average loss is 30.15. Up to batch 7, the average loss is 29.94. Up to batch 8, the average loss is 29.75. Up to batch 9, the average loss is 29.56. Up to batch 10, the average loss is 29.27. Up to batch 11, the average loss is 28.96. Up to batch 12, the average loss is 28.67. Up to batch 13, the average loss is 28.39. Up to batch 14, the average loss is 28.11. Up to batch 15, the average loss is 27.80. The average loss for epoch 4 is 27.80 and mean absolute error is 3.43. </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 5: Learning rate is 0.05000000074505806. Up to batch 0, the average loss is 27.51. Up to batch 1, the average loss is 27.25. Up to batch 2, the average loss is 27.05. Up to batch 3, the average loss is 26.88. Up to batch 4, the average loss is 26.76. Up to batch 5, the average loss is 26.60. Up to batch 6, the average loss is 26.44. Up to batch 7, the average loss is 26.25. Up to batch 8, the average loss is 26.08. Up to batch 9, the average loss is 25.89. Up to batch 10, the average loss is 25.71. Up to batch 11, the average loss is 25.48. Up to batch 12, the average loss is 25.26. Up to batch 13, the average loss is 25.03. Up to batch 14, the average loss is 24.81. Up to batch 15, the average loss is 24.58. The average loss for epoch 5 is 24.58 and mean absolute error is 3.25. </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 6: Learning rate is 0.01. Up to batch 0, the average loss is 24.36. Up to batch 1, the average loss is 24.14. Up to batch 2, the average loss is 23.93. Up to batch 3, the average loss is 23.71. Up to batch 4, the average loss is 23.52. Up to batch 5, the average loss is 23.32. Up to batch 6, the average loss is 23.12. Up to batch 7, the average loss is 22.93. Up to batch 8, the average loss is 22.74. Up to batch 9, the average loss is 22.55. Up to batch 10, the average loss is 22.37. Up to batch 11, the average loss is 22.19. Up to batch 12, the average loss is 22.01. Up to batch 13, the average loss is 21.83. Up to batch 14, the average loss is 21.67. Up to batch 15, the average loss is 21.50. The average loss for epoch 6 is 21.50 and mean absolute error is 2.98. </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 7: Learning rate is 0.009999999776482582. Up to batch 0, the average loss is 21.33. Up to batch 1, the average loss is 21.17. Up to batch 2, the average loss is 21.01. Up to batch 3, the average loss is 20.85. Up to batch 4, the average loss is 20.71. Up to batch 5, the average loss is 20.57. Up to batch 6, the average loss is 20.41. Up to batch 7, the average loss is 20.27. Up to batch 8, the average loss is 20.13. Up to batch 9, the average loss is 19.98. Up to batch 10, the average loss is 19.83. Up to batch 11, the average loss is 19.69. Up to batch 12, the average loss is 19.57. Up to batch 13, the average loss is 19.44. Up to batch 14, the average loss is 19.32. Up to batch 15, the average loss is 19.19. The average loss for epoch 7 is 19.19 and mean absolute error is 2.77. </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 8: Learning rate is 0.009999999776482582. Up to batch 0, the average loss is 19.07. Up to batch 1, the average loss is 18.95. Up to batch 2, the average loss is 18.83. Up to batch 3, the average loss is 18.70. Up to batch 4, the average loss is 18.58. Up to batch 5, the average loss is 18.46. Up to batch 6, the average loss is 18.35. Up to batch 7, the average loss is 18.24. Up to batch 8, the average loss is 18.12. Up to batch 9, the average loss is 18.01. Up to batch 10, the average loss is 17.90. Up to batch 11, the average loss is 17.79. Up to batch 12, the average loss is 17.68. Up to batch 13, the average loss is 17.58. Up to batch 14, the average loss is 17.48. Up to batch 15, the average loss is 17.38. The average loss for epoch 8 is 17.38 and mean absolute error is 2.61. </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 9: Learning rate is 0.005. Up to batch 0, the average loss is 17.28. Up to batch 1, the average loss is 17.18. Up to batch 2, the average loss is 17.08. Up to batch 3, the average loss is 16.99. Up to batch 4, the average loss is 16.90. Up to batch 5, the average loss is 16.80. Up to batch 6, the average loss is 16.71. Up to batch 7, the average loss is 16.62. Up to batch 8, the average loss is 16.53. Up to batch 9, the average loss is 16.44. Up to batch 10, the average loss is 16.35. Up to batch 11, the average loss is 16.26. Up to batch 12, the average loss is 16.17. Up to batch 13, the average loss is 16.09. Up to batch 14, the average loss is 16.00. Up to batch 15, the average loss is 15.92. The average loss for epoch 9 is 15.92 and mean absolute error is 2.48. </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 10: Learning rate is 0.004999999888241291. Up to batch 0, the average loss is 15.84. Up to batch 1, the average loss is 15.76. Up to batch 2, the average loss is 15.68. Up to batch 3, the average loss is 15.61. Up to batch 4, the average loss is 15.53. Up to batch 5, the average loss is 15.45. Up to batch 6, the average loss is 15.37. Up to batch 7, the average loss is 15.29. Up to batch 8, the average loss is 15.23. Up to batch 9, the average loss is 15.15. Up to batch 10, the average loss is 15.08. Up to batch 11, the average loss is 15.00. Up to batch 12, the average loss is 14.93. Up to batch 13, the average loss is 14.86. Up to batch 14, the average loss is 14.79. Up to batch 15, the average loss is 14.72. The average loss for epoch 10 is 14.72 and mean absolute error is 2.37. </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 11: Learning rate is 0.004999999888241291. Up to batch 0, the average loss is 14.65. Up to batch 1, the average loss is 14.58. Up to batch 2, the average loss is 14.52. Up to batch 3, the average loss is 14.45. Up to batch 4, the average loss is 14.39. Up to batch 5, the average loss is 14.33. Up to batch 6, the average loss is 14.26. Up to batch 7, the average loss is 14.20. Up to batch 8, the average loss is 14.14. Up to batch 9, the average loss is 14.08. Up to batch 10, the average loss is 14.02. Up to batch 11, the average loss is 13.96. Up to batch 12, the average loss is 13.90. Up to batch 13, the average loss is 13.84. Up to batch 14, the average loss is 13.78. Up to batch 15, the average loss is 13.72. The average loss for epoch 11 is 13.72 and mean absolute error is 2.27. </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 12: Learning rate is 0.001. Up to batch 0, the average loss is 13.67. Up to batch 1, the average loss is 13.60. Up to batch 2, the average loss is 13.55. Up to batch 3, the average loss is 13.49. Up to batch 4, the average loss is 13.44. Up to batch 5, the average loss is 13.38. Up to batch 6, the average loss is 13.33. Up to batch 7, the average loss is 13.28. Up to batch 8, the average loss is 13.22. Up to batch 9, the average loss is 13.17. Up to batch 10, the average loss is 13.12. Up to batch 11, the average loss is 13.07. Up to batch 12, the average loss is 13.02. Up to batch 13, the average loss is 12.97. Up to batch 14, the average loss is 12.92. Up to batch 15, the average loss is 12.87. The average loss for epoch 12 is 12.87 and mean absolute error is 2.19. </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 13: Learning rate is 0.0010000000474974513. Up to batch 0, the average loss is 12.82. Up to batch 1, the average loss is 12.77. Up to batch 2, the average loss is 12.72. Up to batch 3, the average loss is 12.68. Up to batch 4, the average loss is 12.63. Up to batch 5, the average loss is 12.58. Up to batch 6, the average loss is 12.53. Up to batch 7, the average loss is 12.49. Up to batch 8, the average loss is 12.45. Up to batch 9, the average loss is 12.40. Up to batch 10, the average loss is 12.35. Up to batch 11, the average loss is 12.30. Up to batch 12, the average loss is 12.26. Up to batch 13, the average loss is 12.22. Up to batch 14, the average loss is 12.17. Up to batch 15, the average loss is 12.13. The average loss for epoch 13 is 12.13 and mean absolute error is 2.12. </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 14: Learning rate is 0.0010000000474974513. Up to batch 0, the average loss is 12.09. Up to batch 1, the average loss is 12.05. Up to batch 2, the average loss is 12.01. Up to batch 3, the average loss is 11.97. Up to batch 4, the average loss is 11.92. Up to batch 5, the average loss is 11.88. Up to batch 6, the average loss is 11.84. Up to batch 7, the average loss is 11.80. Up to batch 8, the average loss is 11.76. Up to batch 9, the average loss is 11.72. Up to batch 10, the average loss is 11.68. Up to batch 11, the average loss is 11.64. Up to batch 12, the average loss is 11.60. Up to batch 13, the average loss is 11.57. Up to batch 14, the average loss is 11.54. Up to batch 15, the average loss is 11.50. The average loss for epoch 14 is 11.50 and mean absolute error is 2.06. <keras.src.callbacks.history.History at 0x168619c60> </code></pre></div> </div> <h3 id="builtin-keras-callbacks">Built-in Keras callbacks</h3> <p>Be sure to check out the existing Keras callbacks by reading the <a href="https://keras.io/api/callbacks/">API docs</a>. Applications include logging to CSV, saving the model, visualizing metrics in TensorBoard, and a lot more!</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#writing-your-own-callbacks'>Writing your own callbacks</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#keras-callbacks-overview'>Keras callbacks overview</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#an-overview-of-callback-methods'>An overview of callback methods</a> </div> <div class='k-outline-depth-3'> <a href='#global-methods'>Global methods</a> </div> <div class='k-outline-depth-3'> <a href='#batchlevel-methods-for-trainingtestingpredicting'>Batch-level methods for training/testing/predicting</a> </div> <div class='k-outline-depth-3'> <a href='#epochlevel-methods-training-only'>Epoch-level methods (training only)</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#a-basic-example'>A basic example</a> </div> <div class='k-outline-depth-3'> <a href='#usage-of-logs-dict'>Usage of <code>logs</code> dict</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#usage-of-selfmodel-attribute'>Usage of <code>self.model</code> attribute</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#examples-of-keras-callback-applications'>Examples of Keras callback applications</a> </div> <div class='k-outline-depth-3'> <a href='#early-stopping-at-minimum-loss'>Early stopping at minimum loss</a> </div> <div class='k-outline-depth-3'> <a href='#learning-rate-scheduling'>Learning rate scheduling</a> </div> <div class='k-outline-depth-3'> <a href='#builtin-keras-callbacks'>Built-in Keras callbacks</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>