CINXE.COM
Metrics
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/api/metrics/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Metrics"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Metrics"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Metrics</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer 
src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-link active" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-sublink" href="/api/models/">Models API</a> <a class="nav-sublink" href="/api/layers/">Layers API</a> <a class="nav-sublink" href="/api/callbacks/">Callbacks API</a> <a class="nav-sublink" href="/api/ops/">Ops API</a> <a class="nav-sublink" href="/api/optimizers/">Optimizers</a> <a class="nav-sublink active" href="/api/metrics/">Metrics</a> <a class="nav-sublink2" href="/api/metrics/base_metric/">Base Metric class</a> <a class="nav-sublink2" href="/api/metrics/accuracy_metrics/">Accuracy metrics</a> <a class="nav-sublink2" href="/api/metrics/probabilistic_metrics/">Probabilistic metrics</a> <a class="nav-sublink2" href="/api/metrics/regression_metrics/">Regression metrics</a> <a class="nav-sublink2" href="/api/metrics/classification_metrics/">Classification metrics based on True/False positives & negatives</a> <a class="nav-sublink2" href="/api/metrics/segmentation_metrics/">Image segmentation metrics</a> <a class="nav-sublink2" href="/api/metrics/hinge_metrics/">Hinge metrics for "maximum-margin" 
classification</a> <a class="nav-sublink2" href="/api/metrics/metrics_wrappers/">Metric wrappers and reduction metrics</a> <a class="nav-sublink" href="/api/losses/">Losses</a> <a class="nav-sublink" href="/api/data_loading/">Data loading</a> <a class="nav-sublink" href="/api/datasets/">Built-in small datasets</a> <a class="nav-sublink" href="/api/applications/">Keras Applications</a> <a class="nav-sublink" href="/api/mixed_precision/">Mixed precision</a> <a class="nav-sublink" href="/api/distribution/">Multi-device distribution</a> <a class="nav-sublink" href="/api/random/">RNG API</a> <a class="nav-sublink" href="/api/utils/">Utilities</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'>
<script>
  // Toggle the side navigation (#nav-menu). Invoked by the inline onclick
  // handler on #dropdown-nav, so this function must remain a global.
  function displayDropdownMenu() {
    // Declared locally: the previous bare assignment leaked a global `e`.
    var navMenu = document.getElementById("nav-menu");
    if (navMenu.style.display === "block") {
      navMenu.style.display = "none";
    } else {
      navMenu.style.display = "block";
      document.getElementById("dropdown-nav").style.display = "block";
    }
  }

  // Show either the side navigation or the dropdown button depending on the
  // viewport width, then size the main column relative to the menu height.
  function resetMobileUI() {
    var navMenu = document.getElementById("nav-menu");
    var dropdownNav = document.getElementById("dropdown-nav");
    var isNarrow = window.innerWidth <= 840;

    navMenu.style.display = isNarrow ? "none" : "block";
    dropdownNav.style.display = isNarrow ? "block" : "none";

    // Measured after the display switch above, so a hidden menu reports 0.
    var kmain = document.getElementById("k-main-id");
    kmain.style.minHeight = (navMenu.clientHeight + 100) + 'px';
  }

  window.onresize = resetMobileUI;
  window.addEventListener("load", () => { resetMobileUI(); });
</script>
<div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect
width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form>
<script>
  // Redirect search submissions to /search.html with the typed text in the
  // `query` URL parameter instead of performing a native form submit.
  var form = document.getElementById('search-form');
  form.onsubmit = function(e) {
    e.preventDefault();
    var query = document.getElementById('search-input').value;
    // Escape the text so characters such as `&`, `#`, `+` and `%` stay part
    // of the query value rather than altering the URL structure.
    window.location.href = '/search.html?query=' + encodeURIComponent(query);
    // Lowercase literal: the previous `False` was an undefined identifier
    // and raised a ReferenceError on every submit.
    return false;
  };
</script>
</div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/api/'>Keras 3 API documentation</a> / Metrics </div> <div class='k-content'> <h1 id="metrics">Metrics</h1> <p>A metric is a function that is used to judge the performance of your model.</p> <p>Metric functions are similar to loss functions, except that the results from evaluating a metric are 
not used when training the model. Note that you may use any loss function as a metric.</p> <h2 id="available-metrics">Available metrics</h2> <h3 id="base-metric-class"><a href="/api/metrics/base_metric/">Base Metric class</a></h3> <ul> <li><a href="/api/metrics/base_metric/#metric-class">Metric class</a></li> </ul> <h3 id="accuracy-metrics"><a href="/api/metrics/accuracy_metrics/">Accuracy metrics</a></h3> <ul> <li><a href="/api/metrics/accuracy_metrics/#accuracy-class">Accuracy class</a></li> <li><a href="/api/metrics/accuracy_metrics/#binaryaccuracy-class">BinaryAccuracy class</a></li> <li><a href="/api/metrics/accuracy_metrics/#categoricalaccuracy-class">CategoricalAccuracy class</a></li> <li><a href="/api/metrics/accuracy_metrics/#sparsecategoricalaccuracy-class">SparseCategoricalAccuracy class</a></li> <li><a href="/api/metrics/accuracy_metrics/#topkcategoricalaccuracy-class">TopKCategoricalAccuracy class</a></li> <li><a href="/api/metrics/accuracy_metrics/#sparsetopkcategoricalaccuracy-class">SparseTopKCategoricalAccuracy class</a></li> </ul> <h3 id="probabilistic-metrics"><a href="/api/metrics/probabilistic_metrics/">Probabilistic metrics</a></h3> <ul> <li><a href="/api/metrics/probabilistic_metrics/#binarycrossentropy-class">BinaryCrossentropy class</a></li> <li><a href="/api/metrics/probabilistic_metrics/#categoricalcrossentropy-class">CategoricalCrossentropy class</a></li> <li><a href="/api/metrics/probabilistic_metrics/#sparsecategoricalcrossentropy-class">SparseCategoricalCrossentropy class</a></li> <li><a href="/api/metrics/probabilistic_metrics/#kldivergence-class">KLDivergence class</a></li> <li><a href="/api/metrics/probabilistic_metrics/#poisson-class">Poisson class</a></li> </ul> <h3 id="regression-metrics"><a href="/api/metrics/regression_metrics/">Regression metrics</a></h3> <ul> <li><a href="/api/metrics/regression_metrics/#meansquarederror-class">MeanSquaredError class</a></li> <li><a 
href="/api/metrics/regression_metrics/#rootmeansquarederror-class">RootMeanSquaredError class</a></li> <li><a href="/api/metrics/regression_metrics/#meanabsoluteerror-class">MeanAbsoluteError class</a></li> <li><a href="/api/metrics/regression_metrics/#meanabsolutepercentageerror-class">MeanAbsolutePercentageError class</a></li> <li><a href="/api/metrics/regression_metrics/#meansquaredlogarithmicerror-class">MeanSquaredLogarithmicError class</a></li> <li><a href="/api/metrics/regression_metrics/#cosinesimilarity-class">CosineSimilarity class</a></li> <li><a href="/api/metrics/regression_metrics/#logcosherror-class">LogCoshError class</a></li> <li><a href="/api/metrics/regression_metrics/#r2score-class">R2Score class</a></li> </ul> <h3 id="classification-metrics-based-on-truefalse-positives-amp-negatives"><a href="/api/metrics/classification_metrics/">Classification metrics based on True/False positives & negatives</a></h3> <ul> <li><a href="/api/metrics/classification_metrics/#auc-class">AUC class</a></li> <li><a href="/api/metrics/classification_metrics/#precision-class">Precision class</a></li> <li><a href="/api/metrics/classification_metrics/#recall-class">Recall class</a></li> <li><a href="/api/metrics/classification_metrics/#truepositives-class">TruePositives class</a></li> <li><a href="/api/metrics/classification_metrics/#truenegatives-class">TrueNegatives class</a></li> <li><a href="/api/metrics/classification_metrics/#falsepositives-class">FalsePositives class</a></li> <li><a href="/api/metrics/classification_metrics/#falsenegatives-class">FalseNegatives class</a></li> <li><a href="/api/metrics/classification_metrics/#precisionatrecall-class">PrecisionAtRecall class</a></li> <li><a href="/api/metrics/classification_metrics/#recallatprecision-class">RecallAtPrecision class</a></li> <li><a href="/api/metrics/classification_metrics/#sensitivityatspecificity-class">SensitivityAtSpecificity class</a></li> <li><a 
href="/api/metrics/classification_metrics/#specificityatsensitivity-class">SpecificityAtSensitivity class</a></li> <li><a href="/api/metrics/classification_metrics/#f1score-class">F1Score class</a></li> <li><a href="/api/metrics/classification_metrics/#fbetascore-class">FBetaScore class</a></li> </ul> <h3 id="image-segmentation-metrics"><a href="/api/metrics/segmentation_metrics/">Image segmentation metrics</a></h3> <ul> <li><a href="/api/metrics/segmentation_metrics/#iou-class">IoU class</a></li> <li><a href="/api/metrics/segmentation_metrics/#binaryiou-class">BinaryIoU class</a></li> <li><a href="/api/metrics/segmentation_metrics/#onehotiou-class">OneHotIoU class</a></li> <li><a href="/api/metrics/segmentation_metrics/#onehotmeaniou-class">OneHotMeanIoU class</a></li> <li><a href="/api/metrics/segmentation_metrics/#meaniou-class">MeanIoU class</a></li> </ul> <h3 id="hinge-metrics-for-maximummargin-classification"><a href="/api/metrics/hinge_metrics/">Hinge metrics for "maximum-margin" classification</a></h3> <ul> <li><a href="/api/metrics/hinge_metrics/#hinge-class">Hinge class</a></li> <li><a href="/api/metrics/hinge_metrics/#squaredhinge-class">SquaredHinge class</a></li> <li><a href="/api/metrics/hinge_metrics/#categoricalhinge-class">CategoricalHinge class</a></li> </ul> <h3 id="metric-wrappers-and-reduction-metrics"><a href="/api/metrics/metrics_wrappers/">Metric wrappers and reduction metrics</a></h3> <ul> <li><a href="/api/metrics/metrics_wrappers/#meanmetricwrapper-class">MeanMetricWrapper class</a></li> <li><a href="/api/metrics/metrics_wrappers/#mean-class">Mean class</a></li> <li><a href="/api/metrics/metrics_wrappers/#sum-class">Sum class</a></li> </ul> <hr /> <h2 id="usage-with-compile-amp-fit">Usage with <code>compile()</code> & <code>fit()</code></h2> <p>The <code>compile()</code> method takes a <code>metrics</code> argument, which is a list of metrics:</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span><span 
class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="s1">'adam'</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s1">'mean_squared_error'</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span> <span class="n">metrics</span><span class="o">.</span><span class="n">MeanSquaredError</span><span class="p">(),</span> <span class="n">metrics</span><span class="o">.</span><span class="n">AUC</span><span class="p">(),</span> <span class="p">]</span> <span class="p">)</span> </code></pre></div> <p>Metric values are displayed during <code>fit()</code> and logged to the <code>History</code> object returned by <code>fit()</code>. They are also returned by <code>model.evaluate()</code>.</p> <p>Note that the best way to monitor your metrics during training is via <a href="/api/callbacks/tensorboard">TensorBoard</a>.</p> <p>To track metrics under a specific name, you can pass the <code>name</code> argument to the metric constructor:</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="s1">'adam'</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s1">'mean_squared_error'</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span> <span class="n">metrics</span><span class="o">.</span><span class="n">MeanSquaredError</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">'my_mse'</span><span class="p">),</span> <span class="n">metrics</span><span class="o">.</span><span class="n">AUC</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span 
class="s1">'my_auc'</span><span class="p">),</span> <span class="p">]</span> <span class="p">)</span> </code></pre></div> <p>All built-in metrics may also be passed via their string identifier (in this case, default constructor argument values are used, including a default metric name):</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="s1">'adam'</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s1">'mean_squared_error'</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span> <span class="s1">'MeanSquaredError'</span><span class="p">,</span> <span class="s1">'AUC'</span><span class="p">,</span> <span class="p">]</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="standalone-usage">Standalone usage</h2> <p>Unlike losses, metrics are stateful. 
You update their state using the <code>update_state()</code> method, and you query the scalar metric result using the <code>result()</code> method:</p> <div class="codehilite"><pre><span></span><code><span class="n">m</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">AUC</span><span class="p">()</span> <span class="n">m</span><span class="o">.</span><span class="n">update_state</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">])</span> <span class="nb">print</span><span class="p">(</span><span class="s1">'Intermediate result:'</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span class="n">result</span><span class="p">()))</span> <span class="n">m</span><span class="o">.</span><span class="n">update_state</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">])</span> <span class="nb">print</span><span class="p">(</span><span class="s1">'Final result:'</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="n">m</span><span class="o">.</span><span 
class="n">result</span><span class="p">()))</span> </code></pre></div> <p>The internal state can be cleared via <code>metric.reset_state()</code>.</p> <p>Here's how you would use a metric as part of a simple custom training loop:</p> <div class="codehilite"><pre><span></span><code><span class="n">accuracy</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">CategoricalAccuracy</span><span class="p">()</span> <span class="n">loss_fn</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">CategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">()</span> <span class="c1"># Iterate over the batches of a dataset.</span> <span class="k">for</span> <span class="n">step</span><span class="p">,</span> <span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">dataset</span><span class="p">):</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Compute the loss value for this batch.</span> <span class="n">loss_value</span> <span class="o">=</span> <span 
class="n">loss_fn</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">logits</span><span class="p">)</span> <span class="c1"># Update the state of the `accuracy` metric.</span> <span class="n">accuracy</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">logits</span><span class="p">)</span> <span class="c1"># Update the weights of the model to minimize the loss value.</span> <span class="n">gradients</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">loss_value</span><span class="p">,</span> <span class="n">model</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">)</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="n">model</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">))</span> <span class="c1"># Logging the current accuracy value so far.</span> <span class="k">if</span> <span class="n">step</span> <span class="o">%</span> <span class="mi">100</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="s1">'Step:'</span><span class="p">,</span> <span class="n">step</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s1">'Total running accuracy so far: </span><span class="si">%.3f</span><span class="s1">'</span> <span class="o">%</span> <span class="n">accuracy</span><span class="o">.</span><span class="n">result</span><span class="p">())</span> </code></pre></div> <hr /> <h2 id="creating-custom-metrics">Creating custom 
metrics</h2> <h3 id="as-simple-callables-stateless">As simple callables (stateless)</h3> <p>Much like loss functions, any callable with signature <code>metric_fn(y_true, y_pred)</code> that returns an array of losses (one per sample in the input batch) can be passed to <code>compile()</code> as a metric. Note that sample weighting is automatically supported for any such metric.</p> <p>Here's a simple example:</p> <div class="codehilite"><pre><span></span><code><span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">ops</span> <span class="k">def</span> <span class="nf">my_metric_fn</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span> <span class="n">squared_difference</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">y_true</span> <span class="o">-</span> <span class="n">y_pred</span><span class="p">)</span> <span class="k">return</span> <span class="n">ops</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">squared_difference</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># Note the `axis=-1`</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="s1">'adam'</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s1">'mean_squared_error'</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="n">my_metric_fn</span><span class="p">])</span> </code></pre></div> <p>In this case, the scalar metric value you are tracking during training and evaluation is the average 
of the per-batch metric values for all batches seen during a given epoch (or during a given call to <code>model.evaluate()</code>).</p> <h3 id="as-subclasses-of-metric-stateful">As subclasses of <code>Metric</code> (stateful)</h3> <p>Not all metrics can be expressed via stateless callables, because metrics are evaluated for each batch during training and evaluation, but in some cases the average of the per-batch values is not what you are interested in.</p> <p>Let's say that you want to compute AUC over a given evaluation dataset: the average of the per-batch AUC values isn't the same as the AUC over the entire dataset.</p> <p>For such metrics, you're going to want to subclass the <code>Metric</code> class, which can maintain a state across batches. It's easy:</p> <ul> <li>Create the state variables in <code>__init__</code></li> <li>Update the variables given <code>y_true</code> and <code>y_pred</code> in <code>update_state()</code></li> <li>Return the scalar metric result in <code>result()</code></li> <li>Clear the state in <code>reset_state()</code></li> </ul> <p>Here's a simple example computing binary true positives:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">BinaryTruePositives</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Metric</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">'binary_true_positives'</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span 
class="n">name</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">true_positives</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">'tp'</span><span class="p">,</span> <span class="n">initializer</span><span class="o">=</span><span class="s1">'zeros'</span><span class="p">)</span> <span class="k">def</span> <span class="nf">update_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">,</span> <span class="n">sample_weight</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">y_true</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="s2">"bool"</span><span class="p">)</span> <span class="n">y_pred</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="s2">"bool"</span><span class="p">)</span> <span class="n">values</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">logical_and</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">equal</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="kc">True</span><span class="p">),</span> <span class="n">ops</span><span class="o">.</span><span class="n">equal</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span 
class="kc">True</span><span class="p">))</span> <span class="n">values</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">values</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span> <span class="k">if</span> <span class="n">sample_weight</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="n">sample_weight</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">sample_weight</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span> <span class="n">values</span> <span class="o">=</span> <span class="n">values</span> <span class="o">*</span> <span class="n">sample_weight</span> <span class="bp">self</span><span class="o">.</span><span class="n">true_positives</span><span class="o">.</span><span class="n">assign_add</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">values</span><span class="p">))</span> <span class="k">def</span> <span class="nf">result</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">true_positives</span> <span class="k">def</span> <span class="nf">reset_state</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">true_positives</span><span class="o">.</span><span class="n">assign</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="n">m</span> <span class="o">=</span> 
<span class="n">BinaryTruePositives</span><span class="p">()</span> <span class="n">m</span><span class="o">.</span><span class="n">update_state</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">])</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Intermediate result: </span><span class="si">{</span><span class="n">m</span><span class="o">.</span><span class="n">result</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span> <span class="n">m</span><span class="o">.</span><span class="n">update_state</span><span class="p">([</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">])</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Intermediate result: </span><span class="si">{</span><span class="n">m</span><span class="o">.</span><span class="n">result</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span> </code></pre></div> </div> <div class='k-outline'> <div 
class='k-outline-depth-1'> <a href='#metrics'>Metrics</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#available-metrics'>Available metrics</a> </div> <div class='k-outline-depth-3'> <a href='#base-metric-class'>Base Metric class</a> </div> <div class='k-outline-depth-3'> <a href='#accuracy-metrics'>Accuracy metrics</a> </div> <div class='k-outline-depth-3'> <a href='#probabilistic-metrics'>Probabilistic metrics</a> </div> <div class='k-outline-depth-3'> <a href='#regression-metrics'>Regression metrics</a> </div> <div class='k-outline-depth-3'> <a href='#classification-metrics-based-on-truefalse-positives-amp-negatives'>Classification metrics based on True/False positives & negatives</a> </div> <div class='k-outline-depth-3'> <a href='#image-segmentation-metrics'>Image segmentation metrics</a> </div> <div class='k-outline-depth-3'> <a href='#hinge-metrics-for-maximummargin-classification'>Hinge metrics for "maximum-margin" classification</a> </div> <div class='k-outline-depth-3'> <a href='#metric-wrappers-and-reduction-metrics'>Metric wrappers and reduction metrics</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#usage-with-compile-amp-fit'>Usage with <code>compile()</code> & <code>fit()</code></a> </div> <div class='k-outline-depth-2'> ◆ <a href='#standalone-usage'>Standalone usage</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#creating-custom-metrics'>Creating custom metrics</a> </div> <div class='k-outline-depth-3'> <a href='#as-simple-callables-stateless'>As simple callables (stateless)</a> </div> <div class='k-outline-depth-3'> <a href='#as-subclasses-of-metric-stateful'>As subclasses of <code>Metric</code> (stateful)</a> </div> </div> </div> </div> </div> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </body> </html>