<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/getting_started/intro_to_keras_for_engineers/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Introduction to Keras for engineers"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Introduction to Keras for engineers"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Introduction to Keras for engineers</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 
'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link active" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-sublink active" href="/getting_started/intro_to_keras_for_engineers/">Introduction to Keras for engineers</a> <a class="nav-sublink" href="/getting_started/benchmarks/">Keras 3 benchmarks</a> <a class="nav-sublink" href="/getting_started/ecosystem/">The Keras ecosystem</a> <a class="nav-sublink" href="/getting_started/faq/">Frequently Asked Questions</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function 
displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." 
autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + encodeURIComponent(query); return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/getting_started/'>Getting started</a> / Introduction to Keras for engineers </div> <div class='k-content'> <h1 id="introduction-to-keras-for-engineers">Introduction to Keras for engineers</h1> <p><strong>Author:</strong> <a href="https://twitter.com/fchollet">fchollet</a><br> <strong>Date created:</strong> 2023/07/10<br> <strong>Last modified:</strong> 2023/07/10<br> <strong>Description:</strong> First contact with Keras 3.</p> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a 
href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/intro_to_keras_for_engineers.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/guides/intro_to_keras_for_engineers.py"><strong>GitHub source</strong></a></p>
<hr />
<h2 id="introduction">Introduction</h2>
<p>Keras 3 is a deep learning framework that works with TensorFlow, JAX, and PyTorch interchangeably. This notebook will walk you through key Keras 3 workflows.</p>
<hr />
<h2 id="setup">Setup</h2>
<p>We'll use the JAX backend here, but you can change the string below to <code>"tensorflow"</code> or <code>"torch"</code>, hit "Restart runtime", and the whole notebook will run just the same. This entire guide is backend-agnostic.</p>
<div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">os</span>

<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">&quot;KERAS_BACKEND&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="s2">&quot;jax&quot;</span>

<span class="c1"># Note that Keras should only be imported after the backend</span>
<span class="c1"># has been configured. The backend cannot be changed once the</span>
<span class="c1"># package is imported.</span>
<span class="kn">import</span> <span class="nn">keras</span>
</code></pre></div>
<hr />
<h2 id="a-first-example-a-mnist-convnet">A first example: An MNIST convnet</h2>
<p>Let's start with the Hello World of ML: training a convnet to classify MNIST digits.</p>
<p>Here's the data:</p>
<div class="codehilite"><pre><span></span><code><span class="c1"># Load the data and split it between train and test sets</span>
<span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">mnist</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span>

<span class="c1"># Scale images to the [0, 1] range</span>
<span class="n">x_train</span> <span class="o">=</span> <span class="n">x_train</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">&quot;float32&quot;</span><span class="p">)</span> <span class="o">/</span> <span class="mi">255</span>
<span class="n">x_test</span> <span class="o">=</span> <span class="n">x_test</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">&quot;float32&quot;</span><span class="p">)</span> <span class="o">/</span> <span class="mi">255</span>
<span class="c1"># Make sure images have shape (28, 28, 1)</span>
<span class="n">x_train</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">x_test</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;x_train shape:&quot;</span><span class="p">,</span> <span class="n">x_train</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;y_train shape:&quot;</span><span class="p">,</span> <span class="n">y_train</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">x_train</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="s2">&quot;train samples&quot;</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="n">x_test</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="s2">&quot;test samples&quot;</span><span class="p">)</span>
</code></pre></div>
<div class="k-default-codeblock">
<div class="codehilite"><pre><span></span><code>x_train shape: (60000, 28, 28, 1)
y_train shape: (60000,)
60000 train samples
10000 test samples
</code></pre></div>
</div>
<p>Here's our model.</p>
<p>Keras offers several model-building options, including:</p>
<ul>
<li><a href="https://keras.io/guides/sequential_model/">The Sequential API</a> (what we use below)</li>
<li><a href="https://keras.io/guides/functional_api/">The Functional API</a> (most typical)</li>
<li><a href="https://keras.io/guides/making_new_layers_and_models_via_subclassing/">Writing your own models via subclassing</a> (for 
advanced use cases)</li>
</ul>
<div class="codehilite"><pre><span></span><code><span class="c1"># Model parameters</span>
<span class="n">num_classes</span> <span class="o">=</span> <span class="mi">10</span>
<span class="n">input_shape</span> <span class="o">=</span> <span class="p">(</span><span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>

<span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
    <span class="p">[</span>
        <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">input_shape</span><span class="p">),</span>
        <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">),</span>
        <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">),</span>
        <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">MaxPooling2D</span><span class="p">(</span><span class="n">pool_size</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)),</span>
        <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">),</span>
        <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">),</span>
        <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">GlobalAveragePooling2D</span><span class="p">(),</span>
        <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.5</span><span class="p">),</span>
        <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;softmax&quot;</span><span class="p">),</span>
    <span class="p">]</span>
<span class="p">)</span>
</code></pre></div>
<p>Here's our model summary:</p>
<div class="codehilite"><pre><span></span><code><span class="n">model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span>
</code></pre></div>
<pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "sequential"</span>
</pre>
<pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃<span style="font-weight: bold"> Layer (type)                    </span>┃<span style="font-weight: bold"> Output Shape              </span>┃<span style="font-weight: bold">    Param # </span>┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ conv2d (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>)                 │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">26</span>, <span style="color: #00af00; text-decoration-color: #00af00">26</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>)        │        <span style="color: #00af00; text-decoration-color: #00af00">640</span> │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>)               │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">24</span>, <span style="color: #00af00; text-decoration-color: #00af00">24</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>)        │     <span style="color: #00af00; text-decoration-color: #00af00">36,928</span> │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ max_pooling2d (<span style="color: #0087ff; text-decoration-color: #0087ff">MaxPooling2D</span>)    │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">12</span>, <span style="color: #00af00; text-decoration-color: #00af00">12</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>)        │          <span style="color: #00af00; text-decoration-color: #00af00">0</span> │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_2 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>)               │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">10</span>, <span style="color: #00af00; text-decoration-color: #00af00">10</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>)       │     <span style="color: #00af00; text-decoration-color: #00af00">73,856</span> │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ conv2d_3 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>)               │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>)         │    <span style="color: #00af00; text-decoration-color: #00af00">147,584</span> │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ global_average_pooling2d        │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>)               │          <span style="color: #00af00; text-decoration-color: #00af00">0</span> │
│ (<span style="color: #0087ff; text-decoration-color: #0087ff">GlobalAveragePooling2D</span>)        │                           │            │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dropout (<span style="color: #0087ff; text-decoration-color: #0087ff">Dropout</span>)               │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>)               │          <span style="color: #00af00; text-decoration-color: #00af00">0</span> │
├─────────────────────────────────┼───────────────────────────┼────────────┤
│ dense (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>)                   │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">10</span>)                │      <span style="color: #00af00; text-decoration-color: #00af00">1,290</span> │
└─────────────────────────────────┴───────────────────────────┴────────────┘
</pre>
<pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">260,298</span> (1016.79 KB)
</pre>
<pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">260,298</span> (1016.79 KB)
</pre>
<pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B)
</pre>
<p>We use the <code>compile()</code> method to specify the optimizer, loss function, and the metrics to monitor. 
Note that with the JAX and TensorFlow backends, XLA compilation is turned on by default.</p>
<div class="codehilite"><pre><span></span><code><span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span>
    <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(),</span>
    <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">),</span>
    <span class="n">metrics</span><span class="o">=</span><span class="p">[</span>
        <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">SparseCategoricalAccuracy</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;acc&quot;</span><span class="p">),</span>
    <span class="p">],</span>
<span class="p">)</span>
</code></pre></div>
<p>Let's train and evaluate the model. 
We'll set aside a validation split of 15% of the data during training to monitor generalization on unseen data.</p>
<div class="codehilite"><pre><span></span><code><span class="n">batch_size</span> <span class="o">=</span> <span class="mi">128</span>
<span class="n">epochs</span> <span class="o">=</span> <span class="mi">20</span>

<span class="n">callbacks</span> <span class="o">=</span> <span class="p">[</span>
    <span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">ModelCheckpoint</span><span class="p">(</span><span class="n">filepath</span><span class="o">=</span><span class="s2">&quot;model_at_epoch_</span><span class="si">{epoch}</span><span class="s2">.keras&quot;</span><span class="p">),</span>
    <span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">EarlyStopping</span><span class="p">(</span><span class="n">monitor</span><span class="o">=</span><span class="s2">&quot;val_loss&quot;</span><span class="p">,</span> <span class="n">patience</span><span class="o">=</span><span class="mi">2</span><span class="p">),</span>
<span class="p">]</span>

<span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span>
    <span class="n">x_train</span><span class="p">,</span>
    <span class="n">y_train</span><span class="p">,</span>
    <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
    <span class="n">epochs</span><span class="o">=</span><span class="n">epochs</span><span class="p">,</span>
    <span class="n">validation_split</span><span class="o">=</span><span class="mf">0.15</span><span class="p">,</span>
    <span class="n">callbacks</span><span class="o">=</span><span class="n">callbacks</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">score</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
</code></pre></div>
<div class="k-default-codeblock">
<div class="codehilite"><pre><span></span><code>Epoch 1/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 74s 184ms/step - acc: 0.4980 - loss: 1.3832 - val_acc: 0.9609 - val_loss: 0.1513
Epoch 2/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 74s 186ms/step - acc: 0.9245 - loss: 0.2487 - val_acc: 0.9702 - val_loss: 0.0999
Epoch 3/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 70s 175ms/step - acc: 0.9515 - loss: 0.1647 - val_acc: 0.9816 - val_loss: 0.0608
Epoch 4/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 69s 174ms/step - acc: 0.9622 - loss: 0.1247 - val_acc: 0.9833 - val_loss: 0.0541
Epoch 5/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 68s 171ms/step - acc: 0.9685 - loss: 0.1083 - val_acc: 0.9860 - val_loss: 0.0468
Epoch 6/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 70s 176ms/step - acc: 0.9710 - loss: 0.0955 - val_acc: 0.9897 - val_loss: 0.0400
Epoch 7/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 69s 172ms/step - acc: 0.9742 - loss: 0.0853 - val_acc: 0.9888 - val_loss: 0.0388
Epoch 8/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 68s 169ms/step - acc: 0.9789 - loss: 0.0738 - val_acc: 0.9902 - val_loss: 0.0387
Epoch 9/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 75s 187ms/step - acc: 0.9789 - loss: 0.0691 - val_acc: 0.9907 - val_loss: 0.0341
Epoch 10/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 77s 194ms/step - acc: 0.9806 - loss: 0.0636 - val_acc: 0.9907 - val_loss: 0.0348
Epoch 11/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 74s 186ms/step - acc: 0.9812 - loss: 0.0610 - val_acc: 0.9926 - val_loss: 0.0271
Epoch 12/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 219s 550ms/step - acc: 0.9820 - loss: 0.0590 - val_acc: 0.9912 - val_loss: 0.0294
Epoch 13/20
 399/399 ━━━━━━━━━━━━━━━━━━━━ 70s 176ms/step - acc: 0.9843 - loss: 0.0504 - val_acc: 0.9918 - val_loss: 0.0316
</code></pre></div>
</div>
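<p>Two details of the training log can be checked by hand: the 399 steps per epoch, and the stop after epoch 13 rather than 20. The sketch below is plain Python, not Keras code. The helper names <code>steps_per_epoch</code> and <code>stopping_epoch</code> are hypothetical, and the patience rule is a simplified reading of what <code>EarlyStopping(monitor="val_loss", patience=2)</code> does (stop once the monitored value has failed to beat its best for <code>patience</code> consecutive epochs), not the library's implementation.</p>

```python
import math


def steps_per_epoch(num_samples, val_percent, batch_size):
    """Batches per epoch after holding out `val_percent`% of the samples."""
    num_val = num_samples * val_percent // 100
    num_train = num_samples - num_val
    # The last batch may be partial, hence the ceiling.
    return math.ceil(num_train / batch_size)


def stopping_epoch(val_losses, patience):
    """Return the 1-based epoch at which a simple patience rule stops.

    Simplified assumption: an epoch counts as an improvement only if its
    loss is strictly below the best loss seen so far.
    """
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return len(val_losses)  # rule never triggered: all epochs ran


# val_loss values copied from the training log.
val_losses = [
    0.1513, 0.0999, 0.0608, 0.0541, 0.0468, 0.0400, 0.0388,
    0.0387, 0.0341, 0.0348, 0.0271, 0.0294, 0.0316,
]

print(steps_per_epoch(60000, 15, 128))  # 399
print(stopping_epoch(val_losses, patience=2))  # 13
```

<p>In the log, the lowest <code>val_loss</code> (0.0271) occurs at epoch 11; epochs 12 and 13 both fail to beat it, which exhausts a patience of 2. <code>EarlyStopping</code> leaves <code>restore_best_weights</code> off by default, so the model in memory after <code>fit()</code> carries the epoch-13 weights rather than the epoch-11 ones.</p>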
<p>During training, the <code>ModelCheckpoint</code> callback saved the model at the end of each epoch. You can also save the model in its latest state like this:</p>
<div class="codehilite"><pre><span></span><code><span class="n">model</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s2">&quot;final_model.keras&quot;</span><span class="p">)</span>
</code></pre></div>
<p>And reload it like this:</p>
<div class="codehilite"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">saving</span><span class="o">.</span><span class="n">load_model</span><span class="p">(</span><span class="s2">&quot;final_model.keras&quot;</span><span class="p">)</span>
</code></pre></div>
<p>Next, you can query predictions of class probabilities with <code>predict()</code>:</p>
<div class="codehilite"><pre><span></span><code><span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">x_test</span><span class="p">)</span>
</code></pre></div>
<div class="k-default-codeblock">
<div class="codehilite"><pre><span></span><code> 313/313 ━━━━━━━━━━━━━━━━━━━━ 3s 9ms/step
</code></pre></div>
</div>
<p>That's it for the basics!</p>
<hr />
<h2 id="writing-crossframework-custom-components">Writing cross-framework custom components</h2>
<p>Keras enables you to write custom Layers, Models, Metrics, Losses, and Optimizers that work across TensorFlow, JAX, and PyTorch with the same codebase. Let's take a look at custom layers first.</p>
<p>The <code>keras.ops</code> namespace contains:</p>
<ul>
<li>An implementation of the NumPy API, e.g. 
<a href="/api/ops/numpy#stack-function"><code>keras.ops.stack</code></a> or <a href="/api/ops/numpy#matmul-function"><code>keras.ops.matmul</code></a>.</li>
<li>A set of neural-network-specific ops that are absent from NumPy, such as <a href="/api/ops/nn#conv-function"><code>keras.ops.conv</code></a> or <a href="/api/ops/nn#binarycrossentropy-function"><code>keras.ops.binary_crossentropy</code></a>.</li>
</ul>
<p>Let's make a custom <code>Dense</code> layer that works with all backends:</p>
<div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">MyDense</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">units</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">units</span> <span class="o">=</span> <span class="n">units</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">activations</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">activation</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span>
        <span class="n">input_dim</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span>
            <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">input_dim</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">),</span>
            <span class="n">initializer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">GlorotNormal</span><span class="p">(),</span>
            <span class="n">name</span><span class="o">=</span><span class="s2">&quot;kernel&quot;</span><span class="p">,</span>
            <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
        <span class="p">)</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">b</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span>
            <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">,),</span>
            <span class="n">initializer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">Zeros</span><span class="p">(),</span>
            <span class="n">name</span><span class="o">=</span><span class="s2">&quot;bias&quot;</span><span class="p">,</span>
            <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
        <span class="p">)</span>

    <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span>
        <span class="c1"># Use Keras ops to create backend-agnostic layers/metrics/etc.</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
</code></pre></div>
<p>Next, let's make a custom <code>Dropout</code> layer that relies on the <code>keras.random</code> namespace:</p>
<div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">MyDropout</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rate</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">rate</span> <span class="o">=</span> <span class="n">rate</span>
        <span class="c1"># Use seed_generator for managing RNG state.</span>
        <span class="c1"># It is a state element and its seed variable is</span>
        <span class="c1"># tracked as part of `layer.variables`.</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">seed_generator</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">SeedGenerator</span><span class="p">(</span><span class="mi">1337</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span>
        <span class="c1"># Use `keras.random` for random ops.</span>
        <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">rate</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">seed_generator</span><span class="p">)</span>
</code></pre></div>
<p>Next, let's write a custom subclassed model that uses our two custom layers:</p>
<div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">MyModel</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">):</span>
        <span 
class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_base</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">),</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">),</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">MaxPooling2D</span><span class="p">(</span><span class="n">pool_size</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)),</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span 
class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">),</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">),</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">GlobalAveragePooling2D</span><span class="p">(),</span> <span class="p">]</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dp</span> <span class="o">=</span> <span class="n">MyDropout</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense</span> <span class="o">=</span> <span class="n">MyDense</span><span class="p">(</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;softmax&quot;</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_base</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span 
class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dp</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> </code></pre></div> <p>Let's compile it and fit it:</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="n">MyModel</span><span class="p">(</span><span class="n">num_classes</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(),</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">),</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">SparseCategoricalAccuracy</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;acc&quot;</span><span class="p">),</span> <span class="p">],</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span 
class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="c1"># For speed</span> <span class="n">validation_split</span><span class="o">=</span><span class="mf">0.15</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 399/399 ━━━━━━━━━━━━━━━━━━━━ 70s 174ms/step - acc: 0.5104 - loss: 1.3473 - val_acc: 0.9256 - val_loss: 0.2484 &lt;keras.src.callbacks.history.History at 0x105608670&gt; </code></pre></div> </div> <hr /> <h2 id="training-models-on-arbitrary-data-sources">Training models on arbitrary data sources</h2> <p>All Keras models can be trained and evaluated on a wide variety of data sources, regardless of the backend you're using. This includes:</p> <ul> <li>NumPy arrays</li> <li>pandas DataFrames</li> <li>TensorFlow <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><code>tf.data.Dataset</code></a> objects</li> <li>PyTorch <code>DataLoader</code> objects</li> <li>Keras <code>PyDataset</code> objects</li> </ul> <p>They all work whether you're using TensorFlow, JAX, or PyTorch as your Keras backend.</p> <p>Let's try it out with PyTorch <code>DataLoader</code> objects:</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">torch</span> <span class="c1"># Create the training and validation TensorDatasets</span> <span class="n">train_torch_dataset</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">TensorDataset</span><span class="p">(</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">x_train</span><span
class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">y_train</span><span class="p">)</span> <span class="p">)</span> <span class="n">val_torch_dataset</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">TensorDataset</span><span class="p">(</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">x_test</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">y_test</span><span class="p">)</span> <span class="p">)</span> <span class="c1"># Create a DataLoader</span> <span class="n">train_dataloader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span> <span class="n">train_torch_dataset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="n">val_dataloader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span> <span class="n">val_torch_dataset</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span> 
<span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">MyModel</span><span class="p">(</span><span class="n">num_classes</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(),</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">),</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">SparseCategoricalAccuracy</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;acc&quot;</span><span class="p">),</span> <span class="p">],</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_dataloader</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">val_dataloader</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 469/469 ━━━━━━━━━━━━━━━━━━━━ 81s 172ms/step - acc: 0.5502 - loss: 1.2550 - val_acc: 0.9419 - val_loss: 0.1972 &lt;keras.src.callbacks.history.History at 
0x2b3385480&gt; </code></pre></div> </div> <p>Now let's try this out with <a href="https://www.tensorflow.org/api_docs/python/tf/data"><code>tf.data</code></a>:</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span> <span class="n">train_dataset</span> <span class="o">=</span> <span class="p">(</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">))</span> <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span> <span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span><span class="p">)</span> <span class="p">)</span> <span class="n">test_dataset</span> <span class="o">=</span> <span class="p">(</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">))</span> <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span> <span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span><span class="p">)</span> <span class="p">)</span> <span class="n">model</span> <span 
class="o">=</span> <span class="n">MyModel</span><span class="p">(</span><span class="n">num_classes</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(),</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">),</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">SparseCategoricalAccuracy</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;acc&quot;</span><span class="p">),</span> <span class="p">],</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">test_dataset</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 469/469 ━━━━━━━━━━━━━━━━━━━━ 81s 172ms/step - acc: 0.5771 - loss: 1.1948 - val_acc: 0.9229 - val_loss: 0.2502 &lt;keras.src.callbacks.history.History at 0x2b33e7df0&gt; </code></pre></div> </div> <hr /> <h2 
id="further-reading">Further reading</h2> <p>This concludes our short overview of the new multi-backend capabilities of Keras 3. Next, you can learn about:</p> <h3 id="how-to-customize-what-happens-in-fit">How to customize what happens in <code>fit()</code></h3> <p>Want to implement a non-standard training algorithm yourself but still want to benefit from the power and usability of <code>fit()</code>? It's easy to customize <code>fit()</code> to support arbitrary use cases:</p> <ul> <li><a href="http://keras.io/guides/custom_train_step_in_tensorflow/">Customizing what happens in <code>fit()</code> with TensorFlow</a></li> <li><a href="http://keras.io/guides/custom_train_step_in_jax/">Customizing what happens in <code>fit()</code> with JAX</a></li> <li><a href="http://keras.io/guides/custom_train_step_in_torch/">Customizing what happens in <code>fit()</code> with PyTorch</a></li> </ul> <hr /> <h2 id="how-to-write-custom-training-loops">How to write custom training loops</h2> <ul> <li><a href="http://keras.io/guides/writing_a_custom_training_loop_in_tensorflow/">Writing a training loop from scratch in TensorFlow</a></li> <li><a href="http://keras.io/guides/writing_a_custom_training_loop_in_jax/">Writing a training loop from scratch in JAX</a></li> <li><a href="http://keras.io/guides/writing_a_custom_training_loop_in_torch/">Writing a training loop from scratch in PyTorch</a></li> </ul> <hr /> <h2 id="how-to-distribute-training">How to distribute training</h2> <ul> <li><a href="http://keras.io/guides/distributed_training_with_tensorflow/">Guide to distributed training with TensorFlow</a></li> <li><a href="https://github.com/keras-team/keras/blob/master/examples/demo_jax_distributed.py">JAX distributed training example</a></li> <li><a href="https://github.com/keras-team/keras/blob/master/examples/demo_torch_multi_gpu.py">PyTorch distributed training example</a></li> </ul> <p>Enjoy the library! 
🚀</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#introduction-to-keras-for-engineers'>Introduction to Keras for engineers</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#a-first-example-a-mnist-convnet'>A first example: A MNIST convnet</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#writing-crossframework-custom-components'>Writing cross-framework custom components</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#training-models-on-arbitrary-data-sources'>Training models on arbitrary data sources</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#further-reading'>Further reading</a> </div> <div class='k-outline-depth-3'> <a href='#how-to-customize-what-happens-in-fit'>How to customize what happens in <code>fit()</code></a> </div> <div class='k-outline-depth-2'> ◆ <a href='#how-to-write-custom-training-loops'>How to write custom training loops</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#how-to-distribute-training'>How to distribute training</a> </div> </div> </div> </div> </div> </body> </html>
