<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/guides/custom_train_step_in_tensorflow/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Customizing what happens in `fit()` with TensorFlow"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Customizing what happens in `fit()` with TensorFlow"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Customizing what happens in `fit()` with TensorFlow</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 
'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link active" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-sublink" href="/guides/functional_api/">The Functional API</a> <a class="nav-sublink" href="/guides/sequential_model/">The Sequential model</a> <a class="nav-sublink" href="/guides/making_new_layers_and_models_via_subclassing/">Making new layers & models via subclassing</a> <a class="nav-sublink" href="/guides/training_with_built_in_methods/">Training & evaluation with the built-in methods</a> <a class="nav-sublink" href="/guides/custom_train_step_in_jax/">Customizing `fit()` with JAX</a> <a class="nav-sublink active" href="/guides/custom_train_step_in_tensorflow/">Customizing `fit()` with TensorFlow</a> <a class="nav-sublink" href="/guides/custom_train_step_in_torch/">Customizing `fit()` with PyTorch</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_jax/">Writing a custom training loop in JAX</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_tensorflow/">Writing a custom training loop in TensorFlow</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_torch/">Writing a custom training loop in 
PyTorch</a> <a class="nav-sublink" href="/guides/serialization_and_saving/">Serialization & saving</a> <a class="nav-sublink" href="/guides/customizing_saving_and_serialization/">Customizing saving & serialization</a> <a class="nav-sublink" href="/guides/writing_your_own_callbacks/">Writing your own callbacks</a> <a class="nav-sublink" href="/guides/transfer_learning/">Transfer learning & fine-tuning</a> <a class="nav-sublink" href="/guides/distributed_training_with_jax/">Distributed training with JAX</a> <a class="nav-sublink" href="/guides/distributed_training_with_tensorflow/">Distributed training with TensorFlow</a> <a class="nav-sublink" href="/guides/distributed_training_with_torch/">Distributed training with PyTorch</a> <a class="nav-sublink" href="/guides/distribution/">Distributed training with Keras 3</a> <a class="nav-sublink" href="/guides/migrating_to_keras_3/">Migrating Keras 2 code to Keras 3</a> <a class="nav-link" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; 
document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = 
'/search.html?query=' + encodeURIComponent(query); return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/guides/'>Developer guides</a> / Customizing what happens in `fit()` with TensorFlow </div> <div class='k-content'> <h1 id="customizing-what-happens-in-fit-with-tensorflow">Customizing what happens in <code>fit()</code> with TensorFlow</h1> <p><strong>Author:</strong> <a href="https://twitter.com/fchollet">fchollet</a><br> <strong>Date created:</strong> 2020/04/15<br> <strong>Last modified:</strong> 2023/06/27<br> <strong>Description:</strong> Overriding the training step of the Model class with TensorFlow.</p> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/custom_train_step_in_tensorflow.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/guides/custom_train_step_in_tensorflow.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>When you're doing supervised learning, you can use <code>fit()</code> and everything works smoothly.</p> <p>When you need to take control of every little detail, you can write your own training loop entirely from scratch.</p> <p>But what if you need a custom training algorithm, yet still want to benefit from the convenient features of <code>fit()</code>, such as callbacks, built-in distribution support, or step fusing?</p> <p>A core principle of Keras is <strong>progressive disclosure of complexity</strong>. You should always be able to get into lower-level workflows in a gradual way. You shouldn't fall off a cliff if the high-level functionality doesn't exactly match your use case.
You should be able to gain more control over the small details while retaining a commensurate amount of high-level convenience.</p> <p>When you need to customize what <code>fit()</code> does, you should <strong>override the training step function of the <code>Model</code> class</strong>. This is the function that is called by <code>fit()</code> for every batch of data. You will then be able to call <code>fit()</code> as usual – and it will be running your own learning algorithm.</p> <p>Note that this pattern does not prevent you from building models with the Functional API. You can do this whether you're building <code>Sequential</code> models, Functional API models, or subclassed models.</p> <p>Let's see how that works.</p> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span> <span class="c1"># This guide can only be run with the TF backend.</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"tensorflow"</span> <span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> </code></pre></div> <hr /> <h2 id="a-first-simple-example">A first simple example</h2> <p>Let's start from a simple example:</p> <ul> <li>We create a new class that subclasses <a href="/api/models/model#model-class"><code>keras.Model</code></a>.</li> <li>We just override the method <code>train_step(self, data)</code>.</li> <li>We return a dictionary mapping metric names (including the loss) to their 
current value.</li> </ul> <p>The input argument <code>data</code> is what gets passed to <code>fit()</code> as training data:</p> <ul> <li>If you pass NumPy arrays, by calling <code>fit(x, y, ...)</code>, then <code>data</code> will be the tuple <code>(x, y)</code></li> <li>If you pass a <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><code>tf.data.Dataset</code></a>, by calling <code>fit(dataset, ...)</code>, then <code>data</code> will be what gets yielded by <code>dataset</code> at each batch.</li> </ul> <p>In the body of the <code>train_step()</code> method, we implement a regular training update, similar to what you are already familiar with. Importantly, <strong>we compute the loss via <code>self.compute_loss()</code></strong>, which wraps the loss function(s) that were passed to <code>compile()</code>.</p> <p>Similarly, we call <code>metric.update_state(y, y_pred)</code> on metrics from <code>self.metrics</code>, to update the state of the metrics that were passed in <code>compile()</code>, and we query results from <code>self.metrics</code> at the end to retrieve their current value.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">CustomModel</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span> <span class="c1"># Unpack the data.
Its structure depends on your model and</span> <span class="c1"># on what you pass to `fit()`.</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">data</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="n">y_pred</span> <span class="o">=</span> <span class="bp">self</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="c1"># Forward pass</span> <span class="c1"># Compute the loss value</span> <span class="c1"># (the loss function is configured in `compile()`)</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">compute_loss</span><span class="p">(</span><span class="n">y</span><span class="o">=</span><span class="n">y</span><span class="p">,</span> <span class="n">y_pred</span><span class="o">=</span><span class="n">y_pred</span><span class="p">)</span> <span class="c1"># Compute gradients</span> <span class="n">trainable_vars</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainable_variables</span> <span class="n">gradients</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">)</span> <span class="c1"># Update weights</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span 
class="n">trainable_vars</span><span class="p">)</span> <span class="c1"># Update metrics (includes the metric that tracks the loss)</span> <span class="k">for</span> <span class="n">metric</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics</span><span class="p">:</span> <span class="k">if</span> <span class="n">metric</span><span class="o">.</span><span class="n">name</span> <span class="o">==</span> <span class="s2">"loss"</span><span class="p">:</span> <span class="n">metric</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">metric</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">)</span> <span class="c1"># Return a dict mapping metric names to current value</span> <span class="k">return</span> <span class="p">{</span><span class="n">m</span><span class="o">.</span><span class="n">name</span><span class="p">:</span> <span class="n">m</span><span class="o">.</span><span class="n">result</span><span class="p">()</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics</span><span class="p">}</span> </code></pre></div> <p>Let's try this out:</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Construct and compile an instance of CustomModel</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">32</span><span class="p">,))</span> <span class="n">outputs</span> <span class="o">=</span> <span 
class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">CustomModel</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="s2">"adam"</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"mse"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"mae"</span><span class="p">])</span> <span class="c1"># Just use `fit` as usual</span> <span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">1000</span><span class="p">,</span> <span class="mi">32</span><span class="p">))</span> <span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">1000</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div 
class="codehilite"><pre><span></span><code>Epoch 1/3 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - mae: 0.5089 - loss: 0.3778 Epoch 2/3 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 318us/step - mae: 0.3986 - loss: 0.2466 Epoch 3/3 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 372us/step - mae: 0.3848 - loss: 0.2319 WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1699222602.443035 1 device_compiler.h:187] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. &lt;keras.src.callbacks.history.History at 0x2a5599f00&gt; </code></pre></div> </div> <hr /> <h2 id="going-lowerlevel">Going lower-level</h2> <p>Naturally, you could just skip passing a loss function in <code>compile()</code>, and instead do everything <em>manually</em> in <code>train_step</code>. Likewise for metrics.</p> <p>Here's a lower-level example that only uses <code>compile()</code> to configure the optimizer:</p> <ul> <li>We start by creating <code>Metric</code> instances to track our loss and a MAE score (in <code>__init__()</code>).</li> <li>We implement a custom <code>train_step()</code> that updates the state of these metrics (by calling <code>update_state()</code> on them), then query them (via <code>result()</code>) to return their current average value, to be displayed by the progress bar and to be passed to any callback.</li> <li>Note that we would need to call <code>reset_states()</code> on our metrics between each epoch! Otherwise calling <code>result()</code> would return an average since the start of training, whereas we usually work with per-epoch averages. Thankfully, the framework can do that for us: just list any metric you want to reset in the <code>metrics</code> property of the model.
The model will call <code>reset_states()</code> on any object listed here at the beginning of each <code>fit()</code> epoch or at the beginning of a call to <code>evaluate()</code>.</li> </ul> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">CustomModel</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_tracker</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"loss"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">mae_metric</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">MeanAbsoluteError</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"mae"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span 
class="n">MeanSquaredError</span><span class="p">()</span> <span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">data</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="n">y_pred</span> <span class="o">=</span> <span class="bp">self</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="c1"># Forward pass</span> <span class="c1"># Compute our own loss</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">)</span> <span class="c1"># Compute gradients</span> <span class="n">trainable_vars</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainable_variables</span> <span class="n">gradients</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">)</span> <span class="c1"># Update weights</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="n">trainable_vars</span><span 
class="p">)</span> <span class="c1"># Compute our own metrics</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">mae_metric</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">)</span> <span class="k">return</span> <span class="p">{</span> <span class="s2">"loss"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_tracker</span><span class="o">.</span><span class="n">result</span><span class="p">(),</span> <span class="s2">"mae"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">mae_metric</span><span class="o">.</span><span class="n">result</span><span class="p">(),</span> <span class="p">}</span> <span class="nd">@property</span> <span class="k">def</span> <span class="nf">metrics</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="c1"># We list our `Metric` objects here so that `reset_states()` can be</span> <span class="c1"># called automatically at the start of each epoch</span> <span class="c1"># or at the start of `evaluate()`.</span> <span class="k">return</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">loss_tracker</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mae_metric</span><span class="p">]</span> <span class="c1"># Construct an instance of CustomModel</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span 
class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">32</span><span class="p">,))</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">CustomModel</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> <span class="c1"># We don't pass a loss or metrics here.</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="s2">"adam"</span><span class="p">)</span> <span class="c1"># Just use `fit` as usual -- you can use callbacks, etc.</span> <span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">1000</span><span class="p">,</span> <span class="mi">32</span><span class="p">))</span> <span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">1000</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span> </code></pre></div> <div 
class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/5 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 4.0292 - mae: 1.9270 Epoch 2/5 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 385us/step - loss: 2.2155 - mae: 1.3920 Epoch 3/5 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 336us/step - loss: 1.1863 - mae: 0.9700 Epoch 4/5 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 373us/step - loss: 0.6510 - mae: 0.6811 Epoch 5/5 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 330us/step - loss: 0.4059 - mae: 0.5094 &lt;keras.src.callbacks.history.History at 0x2a7a02860&gt; </code></pre></div> </div> <hr /> <h2 id="supporting-sampleweight-amp-classweight">Supporting <code>sample_weight</code> & <code>class_weight</code></h2> <p>You may have noticed that our first basic example didn't make any mention of sample weighting. If you want to support the <code>fit()</code> arguments <code>sample_weight</code> and <code>class_weight</code>, you'd simply do the following:</p> <ul> <li>Unpack <code>sample_weight</code> from the <code>data</code> argument</li> <li>Pass it to <code>compute_loss</code> & <code>update_state</code> (of course, you could also just apply it manually if you don't rely on <code>compile()</code> for losses & metrics)</li> <li>That's it.</li> </ul> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">CustomModel</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span> <span class="c1"># Unpack the data.
Its structure depends on your model and</span> <span class="c1"># on what you pass to `fit()`.</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">sample_weight</span> <span class="o">=</span> <span class="n">data</span> <span class="k">else</span><span class="p">:</span> <span class="n">sample_weight</span> <span class="o">=</span> <span class="kc">None</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">data</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="n">y_pred</span> <span class="o">=</span> <span class="bp">self</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="c1"># Forward pass</span> <span class="c1"># Compute the loss value.</span> <span class="c1"># The loss function is configured in `compile()`.</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">compute_loss</span><span class="p">(</span> <span class="n">y</span><span class="o">=</span><span class="n">y</span><span class="p">,</span> <span class="n">y_pred</span><span class="o">=</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">sample_weight</span><span class="o">=</span><span class="n">sample_weight</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Compute gradients</span> <span 
class="n">trainable_vars</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainable_variables</span> <span class="n">gradients</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">)</span> <span class="c1"># Update weights</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">)</span> <span class="c1"># Update the metrics.</span> <span class="c1"># Metrics are configured in `compile()`.</span> <span class="k">for</span> <span class="n">metric</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics</span><span class="p">:</span> <span class="k">if</span> <span class="n">metric</span><span class="o">.</span><span class="n">name</span> <span class="o">==</span> <span class="s2">"loss"</span><span class="p">:</span> <span class="n">metric</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">metric</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">,</span> <span class="n">sample_weight</span><span class="o">=</span><span class="n">sample_weight</span><span class="p">)</span> <span class="c1"># Return a dict mapping metric names to current value.</span> <span class="c1"># Note that it will include the loss (tracked in self.metrics).</span> <span class="k">return</span> <span 
class="p">{</span><span class="n">m</span><span class="o">.</span><span class="n">name</span><span class="p">:</span> <span class="n">m</span><span class="o">.</span><span class="n">result</span><span class="p">()</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics</span><span class="p">}</span> <span class="c1"># Construct and compile an instance of CustomModel</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">32</span><span class="p">,))</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">CustomModel</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="s2">"adam"</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"mse"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"mae"</span><span class="p">])</span> <span class="c1"># You can now use sample_weight argument</span> <span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span 
class="n">random</span><span class="p">((</span><span class="mi">1000</span><span class="p">,</span> <span class="mi">32</span><span class="p">))</span> <span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">1000</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="n">sw</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">1000</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">sample_weight</span><span class="o">=</span><span class="n">sw</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/3 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - mae: 0.4228 - loss: 0.1420 Epoch 2/3 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 449us/step - mae: 0.3751 - loss: 0.1058 Epoch 3/3 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 337us/step - mae: 0.3478 - loss: 0.0951 <keras.src.callbacks.history.History at 0x2a7491780> </code></pre></div> </div> <hr /> <h2 id="providing-your-own-evaluation-step">Providing your own evaluation step</h2> <p>What if you want to do the same for calls to <code>model.evaluate()</code>? Then you would override <code>test_step</code> in exactly the same way. 
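</p>

<p>(A quick aside: <code>model.predict()</code> dispatches to <code>predict_step()</code>, which you can override in the same fashion. The sketch below is not part of the original example; it simply strips any extra tuple elements from the batch and runs a plain forward pass in inference mode.)</p>

```python
import numpy as np
import keras


class CustomModel(keras.Model):
    def predict_step(self, data):
        # `predict()` batches may arrive as bare arrays or as tuples;
        # keep only the inputs and run the forward pass in inference mode.
        if isinstance(data, tuple):
            data = data[0]
        return self(data, training=False)


inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
preds = model.predict(np.random.random((8, 32)))
print(preds.shape)  # (8, 1)
```

<p>With that aside out of the way, back to <code>test_step</code>. 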
Here's what it looks like:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">CustomModel</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="nf">test_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span> <span class="c1"># Unpack the data</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">data</span> <span class="c1"># Compute predictions</span> <span class="n">y_pred</span> <span class="o">=</span> <span class="bp">self</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="c1"># Updates the metrics tracking the loss</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">compute_loss</span><span class="p">(</span><span class="n">y</span><span class="o">=</span><span class="n">y</span><span class="p">,</span> <span class="n">y_pred</span><span class="o">=</span><span class="n">y_pred</span><span class="p">)</span> <span class="c1"># Update the metrics.</span> <span class="k">for</span> <span class="n">metric</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics</span><span class="p">:</span> <span class="k">if</span> <span class="n">metric</span><span class="o">.</span><span class="n">name</span> <span class="o">==</span> <span class="s2">"loss"</span><span class="p">:</span> <span class="n">metric</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span> <span class="k">else</span><span 
class="p">:</span> <span class="n">metric</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">)</span> <span class="c1"># Return a dict mapping metric names to current value.</span> <span class="c1"># Note that it will include the loss (tracked in self.metrics).</span> <span class="k">return</span> <span class="p">{</span><span class="n">m</span><span class="o">.</span><span class="n">name</span><span class="p">:</span> <span class="n">m</span><span class="o">.</span><span class="n">result</span><span class="p">()</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics</span><span class="p">}</span> <span class="c1"># Construct an instance of CustomModel</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">32</span><span class="p">,))</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">CustomModel</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">loss</span><span class="o">=</span><span class="s2">"mse"</span><span class="p">,</span> <span class="n">metrics</span><span 
class="o">=</span><span class="p">[</span><span class="s2">"mae"</span><span class="p">])</span> <span class="c1"># Evaluate with our custom test_step</span> <span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">1000</span><span class="p">,</span> <span class="mi">32</span><span class="p">))</span> <span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">1000</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 32/32 ━━━━━━━━━━━━━━━━━━━━ 0s 927us/step - mae: 0.8518 - loss: 0.9166 [0.912325382232666, 0.8567370176315308] </code></pre></div> </div> <hr /> <h2 id="wrapping-up-an-endtoend-gan-example">Wrapping up: an end-to-end GAN example</h2> <p>Let's walk through an end-to-end example that leverages everything you just learned.</p> <p>Let's consider:</p> <ul> <li>A generator network meant to generate 28x28x1 images.</li> <li>A discriminator network meant to classify 28x28x1 images into two classes ("fake" and "real").</li> <li>One optimizer for each.</li> <li>A loss function to train the discriminator.</li> </ul> <div class="codehilite"><pre><span></span><code><span class="c1"># Create the discriminator</span> <span class="n">discriminator</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> 
<span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">1</span><span class="p">)),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">strides</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="n">negative_slope</span><span class="o">=</span><span class="mf">0.2</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">strides</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="n">negative_slope</span><span class="o">=</span><span class="mf">0.2</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span 
class="n">GlobalMaxPooling2D</span><span class="p">(),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="p">],</span> <span class="n">name</span><span class="o">=</span><span class="s2">"discriminator"</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Create the generator</span> <span class="n">latent_dim</span> <span class="o">=</span> <span class="mi">128</span> <span class="n">generator</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">latent_dim</span><span class="p">,)),</span> <span class="c1"># We want to generate 128 coefficients to reshape into a 7x7x128 map</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">7</span> <span class="o">*</span> <span class="mi">7</span> <span class="o">*</span> <span class="mi">128</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="n">negative_slope</span><span class="o">=</span><span class="mf">0.2</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Reshape</span><span class="p">((</span><span class="mi">7</span><span class="p">,</span> <span class="mi">7</span><span class="p">,</span> <span class="mi">128</span><span class="p">)),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2DTranspose</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="p">(</span><span 
class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span> <span class="n">strides</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="n">negative_slope</span><span class="o">=</span><span class="mf">0.2</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2DTranspose</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span> <span class="n">strides</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="n">negative_slope</span><span class="o">=</span><span class="mf">0.2</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="mi">7</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"sigmoid"</span><span class="p">),</span> <span class="p">],</span> <span class="n">name</span><span class="o">=</span><span 
class="s2">"generator"</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <p>Here's a feature-complete GAN class, overriding <code>compile()</code> to use its own signature, and implementing the entire GAN algorithm in 17 lines in <code>train_step</code>:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">GAN</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">discriminator</span><span class="p">,</span> <span class="n">generator</span><span class="p">,</span> <span class="n">latent_dim</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span> <span class="o">=</span> <span class="n">discriminator</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span> <span class="o">=</span> <span class="n">generator</span> <span class="bp">self</span><span class="o">.</span><span class="n">latent_dim</span> <span class="o">=</span> <span class="n">latent_dim</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_loss_tracker</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"d_loss"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">g_loss_tracker</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span 
class="o">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"g_loss"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">seed_generator</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">SeedGenerator</span><span class="p">(</span><span class="mi">1337</span><span class="p">)</span> <span class="nd">@property</span> <span class="k">def</span> <span class="nf">metrics</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="k">return</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">d_loss_tracker</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">g_loss_tracker</span><span class="p">]</span> <span class="k">def</span> <span class="nf">compile</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_optimizer</span><span class="p">,</span> <span class="n">g_optimizer</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">compile</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_optimizer</span> <span class="o">=</span> <span class="n">d_optimizer</span> <span class="bp">self</span><span class="o">.</span><span class="n">g_optimizer</span> <span class="o">=</span> <span class="n">g_optimizer</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span> <span class="o">=</span> <span class="n">loss_fn</span> <span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span 
class="n">real_images</span><span class="p">):</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">real_images</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span> <span class="n">real_images</span> <span class="o">=</span> <span class="n">real_images</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="c1"># Sample random points in the latent space</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">real_images</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="n">random_latent_vectors</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">latent_dim</span><span class="p">),</span> <span class="n">seed</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">seed_generator</span> <span class="p">)</span> <span class="c1"># Decode them to fake images</span> <span class="n">generated_images</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">(</span><span class="n">random_latent_vectors</span><span class="p">)</span> <span class="c1"># Combine them with real images</span> <span class="n">combined_images</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">concat</span><span class="p">([</span><span class="n">generated_images</span><span class="p">,</span> <span 
class="n">real_images</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="c1"># Assemble labels discriminating real from fake images</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">concat</span><span class="p">(</span> <span class="p">[</span><span class="n">tf</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">)),</span> <span class="n">tf</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">))],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span> <span class="p">)</span> <span class="c1"># Add random noise to the labels - important trick!</span> <span class="n">labels</span> <span class="o">+=</span> <span class="mf">0.05</span> <span class="o">*</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span> <span class="n">tf</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">labels</span><span class="p">),</span> <span class="n">seed</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">seed_generator</span> <span class="p">)</span> <span class="c1"># Train the discriminator</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="n">predictions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span 
class="n">discriminator</span><span class="p">(</span><span class="n">combined_images</span><span class="p">)</span> <span class="n">d_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">predictions</span><span class="p">)</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">d_loss</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_optimizer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">grads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">)</span> <span class="c1"># Sample random points in the latent space</span> <span class="n">random_latent_vectors</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">latent_dim</span><span class="p">),</span> <span class="n">seed</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">seed_generator</span> <span class="p">)</span> <span class="c1"># Assemble labels that say "all real images"</span> <span class="n">misleading_labels</span> <span 
class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="c1"># Train the generator (note that we should *not* update the weights</span> <span class="c1"># of the discriminator)!</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="n">predictions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">(</span><span class="n">random_latent_vectors</span><span class="p">))</span> <span class="n">g_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span><span class="p">(</span><span class="n">misleading_labels</span><span class="p">,</span> <span class="n">predictions</span><span class="p">)</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">g_loss</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">g_optimizer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">grads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">)</span> 
<span class="c1"># Update metrics and return their value.</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">d_loss</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">g_loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">g_loss</span><span class="p">)</span> <span class="k">return</span> <span class="p">{</span> <span class="s2">"d_loss"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_loss_tracker</span><span class="o">.</span><span class="n">result</span><span class="p">(),</span> <span class="s2">"g_loss"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">g_loss_tracker</span><span class="o">.</span><span class="n">result</span><span class="p">(),</span> <span class="p">}</span> </code></pre></div> <p>Let's test-drive it:</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Prepare the dataset. 
We use both the training & test MNIST digits.</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">64</span> <span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">_</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">_</span><span class="p">)</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">mnist</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span> <span class="n">all_digits</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">x_train</span><span class="p">,</span> <span class="n">x_test</span><span class="p">])</span> <span class="n">all_digits</span> <span class="o">=</span> <span class="n">all_digits</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">)</span> <span class="o">/</span> <span class="mf">255.0</span> <span class="n">all_digits</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">all_digits</span><span class="p">,</span> <span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">(</span><span class="n">all_digits</span><span class="p">)</span> 
<span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">buffer_size</span><span class="o">=</span><span class="mi">1024</span><span class="p">)</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span> <span class="n">gan</span> <span class="o">=</span> <span class="n">GAN</span><span class="p">(</span><span class="n">discriminator</span><span class="o">=</span><span class="n">discriminator</span><span class="p">,</span> <span class="n">generator</span><span class="o">=</span><span class="n">generator</span><span class="p">,</span> <span class="n">latent_dim</span><span class="o">=</span><span class="n">latent_dim</span><span class="p">)</span> <span class="n">gan</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">d_optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.0003</span><span class="p">),</span> <span class="n">g_optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.0003</span><span class="p">),</span> <span class="n">loss_fn</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">BinaryCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span> <span class="p">)</span> <span 
class="c1"># To limit the execution time, we only train on 100 batches. You can train on</span> <span class="c1"># the entire dataset. You will need about 20 epochs to get nice results.</span> <span class="n">gan</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">dataset</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="mi">100</span><span class="p">),</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 100/100 ━━━━━━━━━━━━━━━━━━━━ 51s 500ms/step - d_loss: 0.5645 - g_loss: 0.7434 &lt;keras.src.callbacks.history.History at 0x14a4f1b10&gt; </code></pre></div> </div> <p>The ideas behind deep learning are simple, so why should their implementation be painful?</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#customizing-what-happens-in-fit-with-tensorflow'>Customizing what happens in <code>fit()</code> with TensorFlow</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#a-first-simple-example'>A first simple example</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#going-lowerlevel'>Going lower-level</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#supporting-sampleweight-amp-classweight'>Supporting <code>sample_weight</code> & <code>class_weight</code></a> </div> <div class='k-outline-depth-2'> ◆ <a href='#providing-your-own-evaluation-step'>Providing your own evaluation step</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#wrapping-up-an-endtoend-gan-example'>Wrapping up: an end-to-end GAN example</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a 
href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>