CINXE.COM
Losses
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/api/losses/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Losses"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Losses"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Losses</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-link active" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-sublink" href="/api/models/">Models API</a> <a class="nav-sublink" href="/api/layers/">Layers API</a> <a class="nav-sublink" href="/api/callbacks/">Callbacks API</a> <a class="nav-sublink" href="/api/ops/">Ops API</a> <a class="nav-sublink" href="/api/optimizers/">Optimizers</a> <a class="nav-sublink" href="/api/metrics/">Metrics</a> <a class="nav-sublink active" href="/api/losses/">Losses</a> <a class="nav-sublink2" href="/api/losses/probabilistic_losses/">Probabilistic losses</a> <a class="nav-sublink2" href="/api/losses/regression_losses/">Regression losses</a> <a class="nav-sublink2" href="/api/losses/hinge_losses/">Hinge losses for "maximum-margin" classification</a> <a class="nav-sublink" href="/api/data_loading/">Data loading</a> <a class="nav-sublink" href="/api/datasets/">Built-in small datasets</a> <a class="nav-sublink" href="/api/applications/">Keras Applications</a> <a class="nav-sublink" href="/api/mixed_precision/">Mixed precision</a> <a class="nav-sublink" href="/api/distribution/">Multi-device distribution</a> <a class="nav-sublink" href="/api/random/">RNG API</a> <a class="nav-sublink" href="/api/utils/">Utilities</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/api/'>Keras 3 API documentation</a> / Losses </div> <div class='k-content'> <h1 id="losses">Losses</h1> <p>The purpose of loss functions is to compute the quantity that a model should seek to minimize during training.</p> <h2 id="available-losses">Available losses</h2> <p>Note that all losses are available both via a class handle and via a function handle. The class handles enable you to pass configuration arguments to the constructor (e.g. <code>loss_fn = CategoricalCrossentropy(from_logits=True)</code>), and they perform reduction by default when used in a standalone way (see details below).</p> <h3 id="probabilistic-losses"><a href="/api/losses/probabilistic_losses/">Probabilistic losses</a></h3> <ul> <li><a href="/api/losses/probabilistic_losses/#binarycrossentropy-class">BinaryCrossentropy class</a></li> <li><a href="/api/losses/probabilistic_losses/#binaryfocalcrossentropy-class">BinaryFocalCrossentropy class</a></li> <li><a href="/api/losses/probabilistic_losses/#categoricalcrossentropy-class">CategoricalCrossentropy class</a></li> <li><a href="/api/losses/probabilistic_losses/#categoricalfocalcrossentropy-class">CategoricalFocalCrossentropy class</a></li> <li><a href="/api/losses/probabilistic_losses/#sparsecategoricalcrossentropy-class">SparseCategoricalCrossentropy class</a></li> <li><a href="/api/losses/probabilistic_losses/#poisson-class">Poisson class</a></li> <li><a href="/api/losses/probabilistic_losses/#ctc-class">CTC class</a></li> <li><a href="/api/losses/probabilistic_losses/#kldivergence-class">KLDivergence class</a></li> <li><a href="/api/losses/probabilistic_losses/#binary_crossentropy-function">binary_crossentropy function</a></li> <li><a href="/api/losses/probabilistic_losses/#categorical_crossentropy-function">categorical_crossentropy function</a></li> <li><a href="/api/losses/probabilistic_losses/#sparse_categorical_crossentropy-function">sparse_categorical_crossentropy function</a></li> <li><a href="/api/losses/probabilistic_losses/#poisson-function">poisson function</a></li> <li><a href="/api/losses/probabilistic_losses/#ctc-function">ctc function</a></li> <li><a href="/api/losses/probabilistic_losses/#kl_divergence-function">kl_divergence function</a></li> </ul> <h3 id="regression-losses"><a href="/api/losses/regression_losses/">Regression losses</a></h3> <ul> <li><a href="/api/losses/regression_losses/#meansquarederror-class">MeanSquaredError class</a></li> <li><a href="/api/losses/regression_losses/#meanabsoluteerror-class">MeanAbsoluteError class</a></li> <li><a href="/api/losses/regression_losses/#meanabsolutepercentageerror-class">MeanAbsolutePercentageError class</a></li> <li><a href="/api/losses/regression_losses/#meansquaredlogarithmicerror-class">MeanSquaredLogarithmicError class</a></li> <li><a href="/api/losses/regression_losses/#cosinesimilarity-class">CosineSimilarity class</a></li> <li><a href="/api/losses/regression_losses/#huber-class">Huber class</a></li> <li><a href="/api/losses/regression_losses/#logcosh-class">LogCosh class</a></li> <li><a href="/api/losses/regression_losses/#tversky-class">Tversky class</a></li> <li><a href="/api/losses/regression_losses/#dice-class">Dice class</a></li> <li><a href="/api/losses/regression_losses/#mean_squared_error-function">mean_squared_error function</a></li> <li><a href="/api/losses/regression_losses/#mean_absolute_error-function">mean_absolute_error function</a></li> <li><a href="/api/losses/regression_losses/#mean_absolute_percentage_error-function">mean_absolute_percentage_error function</a></li> <li><a href="/api/losses/regression_losses/#mean_squared_logarithmic_error-function">mean_squared_logarithmic_error function</a></li> <li><a href="/api/losses/regression_losses/#cosine_similarity-function">cosine_similarity function</a></li> <li><a href="/api/losses/regression_losses/#huber-function">huber function</a></li> <li><a href="/api/losses/regression_losses/#log_cosh-function">log_cosh function</a></li> <li><a href="/api/losses/regression_losses/#tversky-function">tversky function</a></li> <li><a href="/api/losses/regression_losses/#dice-function">dice function</a></li> </ul> <h3 id="hinge-losses-for-maximummargin-classification"><a href="/api/losses/hinge_losses/">Hinge losses for "maximum-margin" classification</a></h3> <ul> <li><a href="/api/losses/hinge_losses/#hinge-class">Hinge class</a></li> <li><a href="/api/losses/hinge_losses/#squaredhinge-class">SquaredHinge class</a></li> <li><a href="/api/losses/hinge_losses/#categoricalhinge-class">CategoricalHinge class</a></li> <li><a href="/api/losses/hinge_losses/#hinge-function">hinge function</a></li> <li><a href="/api/losses/hinge_losses/#squared_hinge-function">squared_hinge function</a></li> <li><a href="/api/losses/hinge_losses/#categorical_hinge-function">categorical_hinge function</a></li> </ul> <hr /> <h2 id="base-loss-api">Base Loss API</h2> <p><span style="float:right;"><a href="https://github.com/keras-team/keras/tree/v3.8.0/keras/src/losses/loss.py#L10">[source]</a></span></p> <h3 id="loss-class"><code>Loss</code> class</h3> <div class="codehilite"><pre><span></span><code><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">Loss</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">reduction</span><span class="o">=</span><span class="s2">"sum_over_batch_size"</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> </code></pre></div> <p>Loss base class.</p> <p>This is the class to subclass in order to create new custom losses.</p> <p><strong>Arguments</strong></p> <ul> <li><strong>reduction</strong>: Type of reduction to apply to the loss. In almost all cases this should be <code>"sum_over_batch_size"</code>. Supported options are <code>"sum"</code>, <code>"sum_over_batch_size"</code>, <code>"mean"</code>, <code>"mean_with_sample_weight"</code> or <code>None</code>. <code>"sum"</code> sums the loss, <code>"sum_over_batch_size"</code> and <code>"mean"</code> sum the loss and divide by the sample size, and <code>"mean_with_sample_weight"</code> sums the loss and divides by the sum of the sample weights. <code>"none"</code> and <code>None</code> perform no aggregation. Defaults to <code>"sum_over_batch_size"</code>.</li> <li><strong>name</strong>: Optional name for the loss instance.</li> <li><strong>dtype</strong>: The dtype of the loss's computations. Defaults to <code>None</code>, which means using <code>keras.backend.floatx()</code>. <code>keras.backend.floatx()</code> is a <code>"float32"</code> unless set to different value (via <code>keras.backend.set_floatx()</code>). If a <code>keras.DTypePolicy</code> is provided, then the <code>compute_dtype</code> will be utilized.</li> </ul> <p>To be implemented by subclasses:</p> <ul> <li><code>call()</code>: Contains the logic for loss calculation using <code>y_true</code>, <code>y_pred</code>.</li> </ul> <p>Example subclass implementation:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">MeanSquaredError</span><span class="p">(</span><span class="n">Loss</span><span class="p">):</span> <span class="k">def</span><span class="w"> </span><span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span> <span class="k">return</span> <span class="n">ops</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">y_pred</span> <span class="o">-</span> <span class="n">y_true</span><span class="p">),</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> </code></pre></div> <hr /> <hr /> <h2 id="usage-of-losses-with-compile-amp-fit">Usage of losses with <code>compile()</code> & <code>fit()</code></h2> <p>A loss function is one of the two arguments required for compiling a Keras model:</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span><span class="w"> </span><span class="nn">keras</span> <span class="kn">from</span><span class="w"> </span><span class="nn">keras</span><span class="w"> </span><span class="kn">import</span> <span class="n">layers</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">kernel_initializer</span><span class="o">=</span><span class="s1">'uniform'</span><span class="p">,</span> <span class="n">input_shape</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,)))</span> <span class="n">model</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="s1">'softmax'</span><span class="p">))</span> <span class="n">loss_fn</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">loss</span><span class="o">=</span><span class="n">loss_fn</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="s1">'adam'</span><span class="p">)</span> </code></pre></div> <p>All built-in loss functions may also be passed via their string identifier:</p> <div class="codehilite"><pre><span></span><code><span class="c1"># pass optimizer by name: default parameters will be used</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">loss</span><span class="o">=</span><span class="s1">'sparse_categorical_crossentropy'</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="s1">'adam'</span><span class="p">)</span> </code></pre></div> <p>Loss functions are typically created by instantiating a loss class (e.g. <a href="/api/losses/probabilistic_losses#sparsecategoricalcrossentropy-class"><code>keras.losses.SparseCategoricalCrossentropy</code></a>). All losses are also provided as function handles (e.g. <a href="/api/losses/probabilistic_losses#sparsecategoricalcrossentropy-function"><code>keras.losses.sparse_categorical_crossentropy</code></a>).</p> <p>Using classes enables you to pass configuration arguments at instantiation time, e.g.:</p> <div class="codehilite"><pre><span></span><code><span class="n">loss_fn</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="standalone-usage-of-losses">Standalone usage of losses</h2> <p>A loss is a callable with arguments <code>loss_fn(y_true, y_pred, sample_weight=None)</code>:</p> <ul> <li><strong>y_true</strong>: Ground truth values, of shape <code>(batch_size, d0, ... dN)</code>. For sparse loss functions, such as sparse categorical crossentropy, the shape should be <code>(batch_size, d0, ... dN-1)</code></li> <li><strong>y_pred</strong>: The predicted values, of shape <code>(batch_size, d0, .. dN)</code>.</li> <li><strong>sample_weight</strong>: Optional <code>sample_weight</code> acts as reduction weighting coefficient for the per-sample losses. If a scalar is provided, then the loss is simply scaled by the given value. If <code>sample_weight</code> is a tensor of size <code>[batch_size]</code>, then the total loss for each sample of the batch is rescaled by the corresponding element in the <code>sample_weight</code> vector. If the shape of <code>sample_weight</code> is <code>(batch_size, d0, ... dN-1)</code> (or can be broadcasted to this shape), then each loss element of <code>y_pred</code> is scaled by the corresponding value of <code>sample_weight</code>. (Note on<code>dN-1</code>: all loss functions reduce by 1 dimension, usually <code>axis=-1</code>.)</li> </ul> <p>By default, loss functions return one scalar loss value for each input sample in the batch dimension, e.g.</p> <div class="codehilite"><pre><span></span><code>>>> from keras import ops >>> keras.losses.mean_squared_error(ops.ones((2, 2,)), ops.zeros((2, 2))) <Array: shape=(2,), dtype=float32, numpy=array([1., 1.], dtype=float32)> </code></pre></div> <p>However, loss class instances feature a <code>reduction</code> constructor argument, which defaults to <code>"sum_over_batch_size"</code> (i.e. average). Allowable values are "sum_over_batch_size", "sum", and "none":</p> <ul> <li>"sum_over_batch_size" means the loss instance will return the average of the per-sample losses in the batch.</li> <li>"sum" means the loss instance will return the sum of the per-sample losses in the batch.</li> <li>"none" means the loss instance will return the full array of per-sample losses.</li> </ul> <div class="codehilite"><pre><span></span><code>>>> loss_fn = keras.losses.MeanSquaredError(reduction='sum_over_batch_size') >>> loss_fn(ops.ones((2, 2,)), ops.zeros((2, 2))) <Array: shape=(), dtype=float32, numpy=1.0> </code></pre></div> <div class="codehilite"><pre><span></span><code>>>> loss_fn = keras.losses.MeanSquaredError(reduction='sum') >>> loss_fn(ops.ones((2, 2,)), ops.zeros((2, 2))) <Array: shape=(), dtype=float32, numpy=2.0> </code></pre></div> <div class="codehilite"><pre><span></span><code>>>> loss_fn = keras.losses.MeanSquaredError(reduction='none') >>> loss_fn(ops.ones((2, 2,)), ops.zeros((2, 2))) <Array: shape=(2,), dtype=float32, numpy=array([1., 1.], dtype=float32)> </code></pre></div> <p>Note that this is an important difference between loss functions like <a href="/api/losses/regression_losses#meansquarederror-function"><code>keras.losses.mean_squared_error</code></a> and default loss class instances like <a href="/api/losses/regression_losses#meansquarederror-class"><code>keras.losses.MeanSquaredError</code></a>: the function version does not perform reduction, but by default the class instance does.</p> <div class="codehilite"><pre><span></span><code>>>> loss_fn = keras.losses.mean_squared_error >>> loss_fn(ops.ones((2, 2,)), ops.zeros((2, 2))) <Array: shape=(2,), dtype=float32, numpy=array([1., 1.], dtype=float32)> </code></pre></div> <div class="codehilite"><pre><span></span><code>>>> loss_fn = keras.losses.MeanSquaredError() >>> loss_fn(ops.ones((2, 2,)), ops.zeros((2, 2))) <Array: shape=(), dtype=float32, numpy=1.0> </code></pre></div> <p>When using <code>fit()</code>, this difference is irrelevant since reduction is handled by the framework.</p> <p>Here's how you would use a loss class instance as part of a simple training loop:</p> <div class="codehilite"><pre><span></span><code><span class="n">loss_fn</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">CategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">()</span> <span class="c1"># Iterate over the batches of a dataset.</span> <span class="k">for</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">dataset</span><span class="p">:</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Compute the loss value for this batch.</span> <span class="n">loss_value</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">logits</span><span class="p">)</span> <span class="c1"># Update the weights of the model to minimize the loss value.</span> <span class="n">gradients</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">loss_value</span><span class="p">,</span> <span class="n">model</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">)</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="n">model</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">))</span> </code></pre></div> <hr /> <h2 id="creating-custom-losses">Creating custom losses</h2> <p>Any callable with the signature <code>loss_fn(y_true, y_pred)</code> that returns an array of losses (one of sample in the input batch) can be passed to <code>compile()</code> as a loss. Note that sample weighting is automatically supported for any such loss.</p> <p>Here's a simple example:</p> <div class="codehilite"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">keras</span><span class="w"> </span><span class="kn">import</span> <span class="n">ops</span> <span class="k">def</span><span class="w"> </span><span class="nf">my_loss_fn</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span> <span class="n">squared_difference</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">y_true</span> <span class="o">-</span> <span class="n">y_pred</span><span class="p">)</span> <span class="k">return</span> <span class="n">ops</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">squared_difference</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># Note the `axis=-1`</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="s1">'adam'</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="n">my_loss_fn</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="the-addloss-api">The <code>add_loss()</code> API</h2> <p>Loss functions applied to the output of a model aren't the only way to create losses.</p> <p>When writing the <code>call</code> method of a custom layer or a subclassed model, you may want to compute scalar quantities that you want to minimize during training (e.g. regularization losses). You can use the <code>add_loss()</code> layer method to keep track of such loss terms.</p> <p>Here's an example of a layer that adds a sparsity regularization loss based on the L2 norm of the inputs:</p> <div class="codehilite"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">keras</span><span class="w"> </span><span class="kn">import</span> <span class="n">ops</span> <span class="k">class</span><span class="w"> </span><span class="nc">MyActivityRegularizer</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Layer that creates an activity sparsity regularization loss."""</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rate</span><span class="o">=</span><span class="mf">1e-2</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">rate</span> <span class="o">=</span> <span class="n">rate</span> <span class="k">def</span><span class="w"> </span><span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="c1"># We use `add_loss` to create a regularization loss</span> <span class="c1"># that depends on the inputs.</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_loss</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">rate</span> <span class="o">*</span> <span class="n">ops</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">inputs</span><span class="p">)))</span> <span class="k">return</span> <span class="n">inputs</span> </code></pre></div> <p>Loss values added via <code>add_loss</code> can be retrieved in the <code>.losses</code> list property of any <code>Layer</code> or <code>Model</code> (they are recursively retrieved from every underlying layer):</p> <div class="codehilite"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">keras</span><span class="w"> </span><span class="kn">import</span> <span class="n">layers</span> <span class="kn">from</span><span class="w"> </span><span class="nn">keras</span><span class="w"> </span><span class="kn">import</span> <span class="n">ops</span> <span class="k">class</span><span class="w"> </span><span class="nc">SparseMLP</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Stack of Linear layers with a sparsity regularization loss."""</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output_dim</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">ops</span><span class="o">.</span><span class="n">relu</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">regularization</span> <span class="o">=</span> <span class="n">MyActivityRegularizer</span><span class="p">(</span><span class="mf">1e-2</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">output_dim</span><span class="p">)</span> <span class="k">def</span><span class="w"> </span><span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_1</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">regularization</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">mlp</span> <span class="o">=</span> <span class="n">SparseMLP</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="n">y</span> <span class="o">=</span> <span class="n">mlp</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">)))</span> <span class="nb">print</span><span class="p">(</span><span class="n">mlp</span><span class="o">.</span><span class="n">losses</span><span class="p">)</span> <span class="c1"># List containing one float32 scalar</span> </code></pre></div> <p>These losses are cleared by the top-level layer at the start of each forward pass – they don't accumulate. So <code>layer.losses</code> always contain only the losses created during the last forward pass. You would typically use these losses by summing them before computing your gradients when writing a training loop.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Losses correspond to the *last* forward pass.</span> <span class="n">mlp</span> <span class="o">=</span> <span class="n">SparseMLP</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="n">mlp</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">)))</span> <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">mlp</span><span class="o">.</span><span class="n">losses</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span> <span class="n">mlp</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">)))</span> <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">mlp</span><span class="o">.</span><span class="n">losses</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span> <span class="c1"># No accumulation.</span> </code></pre></div> <p>When using <code>model.fit()</code>, such loss terms are handled automatically.</p> <p>When writing a custom training loop, you should retrieve these terms by hand from <code>model.losses</code>, like this:</p> <div class="codehilite"><pre><span></span><code><span class="n">loss_fn</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">CategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">()</span> <span class="c1"># Iterate over the batches of a dataset.</span> <span class="k">for</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">dataset</span><span class="p">:</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="c1"># Forward pass.</span> <span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Loss value for this batch.</span> <span class="n">loss_value</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">logits</span><span class="p">)</span> <span class="c1"># Add extra loss terms to the loss value.</span> <span class="n">loss_value</span> <span class="o">+=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">losses</span><span class="p">)</span> <span class="c1"># Update the weights of the model to minimize the loss value.</span> <span class="n">gradients</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">loss_value</span><span class="p">,</span> <span class="n">model</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">)</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="n">model</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">))</span> </code></pre></div> <p>See <a href="/api/layers/base_layer/#add_loss-method">the <code>add_loss()</code> documentation</a> for more details.</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#losses'>Losses</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#available-losses'>Available losses</a> </div> <div class='k-outline-depth-3'> <a href='#probabilistic-losses'>Probabilistic losses</a> </div> <div class='k-outline-depth-3'> <a href='#regression-losses'>Regression losses</a> </div> <div class='k-outline-depth-3'> <a href='#hinge-losses-for-maximummargin-classification'>Hinge losses for "maximum-margin" classification</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#base-loss-api'>Base Loss API</a> </div> <div class='k-outline-depth-3'> <a href='#loss-class'><code>Loss</code> class</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#usage-of-losses-with-compile-amp-fit'>Usage of losses with <code>compile()</code> & <code>fit()</code></a> </div> <div class='k-outline-depth-2'> ◆ <a href='#standalone-usage-of-losses'>Standalone usage of losses</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#creating-custom-losses'>Creating custom losses</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#the-addloss-api'>The <code>add_loss()</code> API</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>