<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/guides/functional_api/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: The Functional API"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: The Functional API"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>The Functional API</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer 
src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link active" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-sublink active" href="/guides/functional_api/">The Functional API</a> <a class="nav-sublink" href="/guides/sequential_model/">The Sequential model</a> <a class="nav-sublink" href="/guides/making_new_layers_and_models_via_subclassing/">Making new layers & models via subclassing</a> <a class="nav-sublink" href="/guides/training_with_built_in_methods/">Training & evaluation with the built-in methods</a> <a class="nav-sublink" href="/guides/custom_train_step_in_jax/">Customizing `fit()` with JAX</a> <a class="nav-sublink" href="/guides/custom_train_step_in_tensorflow/">Customizing `fit()` with TensorFlow</a> <a class="nav-sublink" href="/guides/custom_train_step_in_torch/">Customizing `fit()` with PyTorch</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_jax/">Writing a custom training loop in JAX</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_tensorflow/">Writing a custom training loop in TensorFlow</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_torch/">Writing a custom training loop in PyTorch</a> <a class="nav-sublink" href="/guides/serialization_and_saving/">Serialization & saving</a> <a 
class="nav-sublink" href="/guides/customizing_saving_and_serialization/">Customizing saving & serialization</a> <a class="nav-sublink" href="/guides/writing_your_own_callbacks/">Writing your own callbacks</a> <a class="nav-sublink" href="/guides/transfer_learning/">Transfer learning & fine-tuning</a> <a class="nav-sublink" href="/guides/distributed_training_with_jax/">Distributed training with JAX</a> <a class="nav-sublink" href="/guides/distributed_training_with_tensorflow/">Distributed training with TensorFlow</a> <a class="nav-sublink" href="/guides/distributed_training_with_torch/">Distributed training with PyTorch</a> <a class="nav-sublink" href="/guides/distribution/">Distributed training with Keras 3</a> <a class="nav-sublink" href="/guides/migrating_to_keras_3/">Migrating Keras 2 code to Keras 3</a> <a class="nav-link" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = 
navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> 
<span class="k-location-slug-pointer">►</span> <a href='/guides/'>Developer guides</a> / The Functional API </div> <div class='k-content'> <h1 id="the-functional-api">The Functional API</h1> <p><strong>Author:</strong> <a href="https://twitter.com/fchollet">fchollet</a><br> <strong>Date created:</strong> 2019/03/01<br> <strong>Last modified:</strong> 2023/06/25<br> <strong>Description:</strong> Complete guide to the functional API.</p> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/functional_api.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/guides/functional_api.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">keras</span>
<span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span>
<span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">ops</span>
</code></pre></div> <hr /> <h2 id="introduction">Introduction</h2> <p>The Keras <em>functional API</em> is a way to create models that are more flexible than the <a href="/api/models/sequential#sequential-class"><code>keras.Sequential</code></a> API. The functional API can handle models with non-linear topology, shared layers, and even multiple inputs or outputs.</p> <p>The main idea is that a deep learning model is usually a directed acyclic graph (DAG) of layers. 
So the functional API is a way to build <em>graphs of layers</em>.</p> <p>Consider the following model:</p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>(input: 784-dimensional vectors)
↧
[Dense (64 units, relu activation)]
↧
[Dense (64 units, relu activation)]
↧
[Dense (10 units, no activation)]
↧
(output: logits of a probability distribution over 10 classes)
</code></pre></div> </div> <p>This is a basic graph with three layers. To build this model using the functional API, start by creating an input node:</p> <div class="codehilite"><pre><span></span><code><span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">784</span><span class="p">,))</span> </code></pre></div> <p>The shape of the data is set as a 784-dimensional vector. The batch size is always omitted since only the shape of each sample is specified.</p> <p>If, for example, you have an image input with a shape of <code>(32, 32, 3)</code>, you would use:</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Just for demonstration purposes.</span>
<span class="n">img_inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
</code></pre></div> <p>The <code>inputs</code> object that is returned contains information about the shape and <code>dtype</code> of the input data that you feed to your model. 
Here's the shape:</p> <div class="codehilite"><pre><span></span><code><span class="n">inputs</span><span class="o">.</span><span class="n">shape</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>(None, 784) </code></pre></div> </div> <p>Here's the dtype:</p> <div class="codehilite"><pre><span></span><code><span class="n">inputs</span><span class="o">.</span><span class="n">dtype</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>'float32' </code></pre></div> </div> <p>You create a new node in the graph of layers by calling a layer on this <code>inputs</code> object:</p> <div class="codehilite"><pre><span></span><code><span class="n">dense</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">dense</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span>
</code></pre></div> <p>The "layer call" action is like drawing an arrow from "inputs" to this layer you created. 
You're "passing" the inputs to the <code>dense</code> layer, and you get <code>x</code> as the output.</p> <p>Let's add a few more layers to the graph of layers:</p> <div class="codehilite"><pre><span></span><code><span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
<span class="n">outputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
</code></pre></div> <p>At this point, you can create a <code>Model</code> by specifying its inputs and outputs in the graph of layers:</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">outputs</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"mnist_model"</span><span class="p">)</span> </code></pre></div> <p>Let's check out what the model summary looks like:</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> </code></pre></div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "mnist_model"</span> </pre> <pre 
style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ input_layer (<span style="color: #0087ff; text-decoration-color: #0087ff">InputLayer</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">784</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">50,240</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">4,160</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ dense_2 (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">10</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">650</span> │ 
└─────────────────────────────────┴────────────────────────┴───────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">55,050</span> (215.04 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">55,050</span> (215.04 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <p>You can also plot the model as a graph:</p> <div class="codehilite"><pre><span></span><code><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">plot_model</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="s2">"my_first_model.png"</span><span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/guides/functional_api/functional_api_20_0.png" /></p> <p>And, optionally, display the input and output shapes of each layer in the plotted graph:</p> <div class="codehilite"><pre><span></span><code><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">plot_model</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="s2">"my_first_model_with_shape_info.png"</span><span class="p">,</span> <span class="n">show_shapes</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> </code></pre></div> <p><img 
alt="png" src="/img/guides/functional_api/functional_api_22_0.png" /></p> <p>This figure and the code are almost identical. In the code version, the connection arrows are replaced by the call operation.</p> <p>A "graph of layers" is an intuitive mental image for a deep learning model, and the functional API is a way to create models that closely mirrors this.</p> <hr /> <h2 id="training-evaluation-and-inference">Training, evaluation, and inference</h2> <p>Training, evaluation, and inference work in exactly the same way for models built using the functional API as for <code>Sequential</code> models.</p> <p>The <code>Model</code> class offers a built-in training loop (the <code>fit()</code> method) and a built-in evaluation loop (the <code>evaluate()</code> method). Note that you can easily customize these loops to implement your own training routines. See also the guides on customizing what happens in <code>fit()</code>:</p> <ul> <li><a href="/guides/custom_train_step_in_tensorflow/">Writing a custom train step with TensorFlow</a></li> <li><a href="/guides/custom_train_step_in_jax/">Writing a custom train step with JAX</a></li> <li><a href="/guides/custom_train_step_in_torch/">Writing a custom train step with PyTorch</a></li> </ul> <p>Here, load the MNIST image data, reshape it into vectors, fit the model on the data (while monitoring performance on a validation split), then evaluate the model on the test data:</p> <div class="codehilite"><pre><span></span><code><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">mnist</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span>
<span class="n">x_train</span> <span class="o">=</span> <span class="n">x_train</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">60000</span><span class="p">,</span> <span class="mi">784</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">)</span> <span class="o">/</span> <span class="mi">255</span>
<span class="n">x_test</span> <span class="o">=</span> <span class="n">x_test</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">10000</span><span class="p">,</span> <span class="mi">784</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">)</span> <span class="o">/</span> <span class="mi">255</span>
<span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span>
    <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span>
    <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">RMSprop</span><span class="p">(),</span>
    <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"accuracy"</span><span class="p">],</span>
<span class="p">)</span>
<span class="n">history</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">validation_split</span><span class="o">=</span><span class="mf">0.2</span><span class="p">)</span>
<span class="n">test_scores</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">"Test loss:"</span><span class="p">,</span> <span class="n">test_scores</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">"Test accuracy:"</span><span class="p">,</span> <span class="n">test_scores</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/2
750/750 ━━━━━━━━━━━━━━━━━━━━ 1s 863us/step - accuracy: 0.8425 - loss: 0.5733 - val_accuracy: 0.9496 - val_loss: 0.1711
Epoch 2/2
750/750 ━━━━━━━━━━━━━━━━━━━━ 1s 859us/step - accuracy: 0.9509 - loss: 0.1641 - val_accuracy: 0.9578 - val_loss: 0.1396
313/313 - 0s - 341us/step - accuracy: 0.9613 - loss: 0.1288
Test loss: 0.12876172363758087
Test accuracy: 0.9613000154495239
</code></pre></div> </div> <p>For further reading, see the <a href="/guides/training_with_built_in_methods/">training and evaluation</a> guide.</p> <hr /> <h2 id="save-and-serialize">Save and serialize</h2> <p>Saving the model and serialization work the same way for models built using the functional API as they do for 
<code>Sequential</code> models. The standard way to save a functional model is to call <code>model.save()</code> to save the entire model as a single file. You can later recreate the same model from this file, even if the code that built the model is no longer available.</p> <p>This saved file includes:</p> <ul> <li>the model architecture</li> <li>the model weight values (learned during training)</li> <li>the model training config, if any (as passed to <code>compile()</code>)</li> <li>the optimizer and its state, if any (to restart training where you left off)</li> </ul> <div class="codehilite"><pre><span></span><code><span class="n">model</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s2">"my_model.keras"</span><span class="p">)</span>
<span class="k">del</span> <span class="n">model</span>
<span class="c1"># Recreate the exact same model purely from the file:</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">load_model</span><span class="p">(</span><span class="s2">"my_model.keras"</span><span class="p">)</span>
</code></pre></div> <p>For details, read the model <a href="/guides/serialization_and_saving/">serialization & saving</a> guide.</p> <hr /> <h2 id="use-the-same-graph-of-layers-to-define-multiple-models">Use the same graph of layers to define multiple models</h2> <p>In the functional API, models are created by specifying their inputs and outputs in a graph of layers. 
That means that a single graph of layers can be used to generate multiple models.</p> <p>In the example below, you use the same stack of layers to instantiate two models: an <code>encoder</code> model that turns image inputs into 16-dimensional vectors, and an end-to-end <code>autoencoder</code> model for training.</p> <div class="codehilite"><pre><span></span><code><span class="n">encoder_input</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"img"</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">encoder_input</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MaxPooling2D</span><span class="p">(</span><span class="mi">3</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
<span class="n">encoder_output</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">GlobalMaxPooling2D</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span>
<span class="n">encoder</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">encoder_input</span><span class="p">,</span> <span class="n">encoder_output</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"encoder"</span><span class="p">)</span>
<span class="n">encoder</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Reshape</span><span class="p">((</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">1</span><span class="p">))(</span><span class="n">encoder_output</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2DTranspose</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2DTranspose</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">UpSampling2D</span><span class="p">(</span><span class="mi">3</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2DTranspose</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
<span class="n">decoder_output</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2DTranspose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
<span class="n">autoencoder</span> <span class="o">=</span> <span 
class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">encoder_input</span><span class="p">,</span> <span class="n">decoder_output</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"autoencoder"</span><span class="p">)</span> <span class="n">autoencoder</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> </code></pre></div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "encoder"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ img (<span style="color: #0087ff; text-decoration-color: #0087ff">InputLayer</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">28</span>, <span style="color: #00af00; text-decoration-color: #00af00">28</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">26</span>, <span style="color: #00af00; text-decoration-color: #00af00">26</span>, <span style="color: #00af00; 
text-decoration-color: #00af00">16</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">160</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">24</span>, <span style="color: #00af00; text-decoration-color: #00af00">24</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">4,640</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d (<span style="color: #0087ff; text-decoration-color: #0087ff">MaxPooling2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_2 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">6</span>, <span style="color: #00af00; text-decoration-color: #00af00">6</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">9,248</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_3 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: 
#00af00; text-decoration-color: #00af00">4</span>, <span style="color: #00af00; text-decoration-color: #00af00">4</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">4,624</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ global_max_pooling2d │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">GlobalMaxPooling2D</span>) │ │ │ └─────────────────────────────────┴────────────────────────┴───────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">18,672</span> (72.94 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">18,672</span> (72.94 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "autoencoder"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier 
New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ img (<span style="color: #0087ff; text-decoration-color: #0087ff">InputLayer</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">28</span>, <span style="color: #00af00; text-decoration-color: #00af00">28</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">26</span>, <span style="color: #00af00; text-decoration-color: #00af00">26</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">160</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">24</span>, <span style="color: #00af00; text-decoration-color: #00af00">24</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">4,640</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d (<span style="color: #0087ff; text-decoration-color: 
#0087ff">MaxPooling2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_2 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">6</span>, <span style="color: #00af00; text-decoration-color: #00af00">6</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">9,248</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_3 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">4</span>, <span style="color: #00af00; text-decoration-color: #00af00">4</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">4,624</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ global_max_pooling2d │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">GlobalMaxPooling2D</span>) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ reshape (<span style="color: #0087ff; 
text-decoration-color: #0087ff">Reshape</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">4</span>, <span style="color: #00af00; text-decoration-color: #00af00">4</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_transpose │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">6</span>, <span style="color: #00af00; text-decoration-color: #00af00">6</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">160</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2DTranspose</span>) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_transpose_1 │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">4,640</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2DTranspose</span>) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ up_sampling2d (<span style="color: #0087ff; text-decoration-color: #0087ff">UpSampling2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">24</span>, <span style="color: #00af00; text-decoration-color: #00af00">24</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>) 
│ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_transpose_2 │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">26</span>, <span style="color: #00af00; text-decoration-color: #00af00">26</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">4,624</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2DTranspose</span>) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_transpose_3 │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">28</span>, <span style="color: #00af00; text-decoration-color: #00af00">28</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">145</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2DTranspose</span>) │ │ │ └─────────────────────────────────┴────────────────────────┴───────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">28,241</span> (110.32 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">28,241</span> (110.32 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span 
style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <p>Here, the decoding architecture is strictly symmetrical to the encoding architecture, so the output shape is the same as the input shape <code>(28, 28, 1)</code>.</p> <p>The reverse of a <code>Conv2D</code> layer is a <code>Conv2DTranspose</code> layer, and the reverse of a <code>MaxPooling2D</code> layer is an <code>UpSampling2D</code> layer.</p> <hr /> <h2 id="all-models-are-callable-just-like-layers">All models are callable, just like layers</h2> <p>You can treat any model as if it were a layer by invoking it on an <code>Input</code> or on the output of another layer. By calling a model, you aren't just reusing its architecture; you're also reusing its weights.</p> <p>To see this in action, here's a different take on the autoencoder example that creates an encoder model, a decoder model, and chains them in two calls to obtain the autoencoder model:</p> <div class="codehilite"><pre><span></span><code><span class="n">encoder_input</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"original_img"</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">encoder_input</span><span 
class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MaxPooling2D</span><span class="p">(</span><span class="mi">3</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">encoder_output</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">GlobalMaxPooling2D</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">encoder</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">encoder_input</span><span class="p">,</span> <span 
class="n">encoder_output</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"encoder"</span><span class="p">)</span> <span class="n">encoder</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> <span class="n">decoder_input</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">16</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"encoded_img"</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Reshape</span><span class="p">((</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">1</span><span class="p">))(</span><span class="n">decoder_input</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2DTranspose</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2DTranspose</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span 
class="o">.</span><span class="n">UpSampling2D</span><span class="p">(</span><span class="mi">3</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2DTranspose</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">decoder_output</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2DTranspose</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">decoder</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">decoder_input</span><span class="p">,</span> <span class="n">decoder_output</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"decoder"</span><span class="p">)</span> <span class="n">decoder</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> <span class="n">autoencoder_input</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span 
class="s2">"img"</span><span class="p">)</span> <span class="n">encoded_img</span> <span class="o">=</span> <span class="n">encoder</span><span class="p">(</span><span class="n">autoencoder_input</span><span class="p">)</span> <span class="n">decoded_img</span> <span class="o">=</span> <span class="n">decoder</span><span class="p">(</span><span class="n">encoded_img</span><span class="p">)</span> <span class="n">autoencoder</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">autoencoder_input</span><span class="p">,</span> <span class="n">decoded_img</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"autoencoder"</span><span class="p">)</span> <span class="n">autoencoder</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> </code></pre></div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "encoder"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ original_img (<span style="color: #0087ff; text-decoration-color: #0087ff">InputLayer</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">28</span>, <span style="color: #00af00; text-decoration-color: #00af00">28</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>) │ <span style="color: #00af00; 
text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_4 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">26</span>, <span style="color: #00af00; text-decoration-color: #00af00">26</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">160</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_5 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">24</span>, <span style="color: #00af00; text-decoration-color: #00af00">24</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">4,640</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ max_pooling2d_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">MaxPooling2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_6 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">6</span>, <span style="color: #00af00; 
text-decoration-color: #00af00">6</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">9,248</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_7 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">4</span>, <span style="color: #00af00; text-decoration-color: #00af00">4</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">4,624</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ global_max_pooling2d_1 │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">GlobalMaxPooling2D</span>) │ │ │ └─────────────────────────────────┴────────────────────────┴───────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">18,672</span> (72.94 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">18,672</span> (72.94 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span 
style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "decoder"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ encoded_img (<span style="color: #0087ff; text-decoration-color: #0087ff">InputLayer</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ reshape_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">Reshape</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">4</span>, <span style="color: #00af00; text-decoration-color: #00af00">4</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_transpose_4 │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">6</span>, <span style="color: #00af00; text-decoration-color: #00af00">6</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>) │ <span style="color: #00af00; 
text-decoration-color: #00af00">160</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2DTranspose</span>) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_transpose_5 │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">4,640</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2DTranspose</span>) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ up_sampling2d_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">UpSampling2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">24</span>, <span style="color: #00af00; text-decoration-color: #00af00">24</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_transpose_6 │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">26</span>, <span style="color: #00af00; text-decoration-color: #00af00">26</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">4,624</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2DTranspose</span>) │ │ │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ conv2d_transpose_7 │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span 
style="color: #00af00; text-decoration-color: #00af00">28</span>, <span style="color: #00af00; text-decoration-color: #00af00">28</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">145</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2DTranspose</span>) │ │ │ └─────────────────────────────────┴────────────────────────┴───────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">9,569</span> (37.38 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">9,569</span> (37.38 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "autoencoder"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩ │ img (<span style="color: #0087ff; 
text-decoration-color: #0087ff">InputLayer</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">28</span>, <span style="color: #00af00; text-decoration-color: #00af00">28</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ encoder (<span style="color: #0087ff; text-decoration-color: #0087ff">Functional</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">18,672</span> │ ├─────────────────────────────────┼────────────────────────┼───────────────┤ │ decoder (<span style="color: #0087ff; text-decoration-color: #0087ff">Functional</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">28</span>, <span style="color: #00af00; text-decoration-color: #00af00">28</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">9,569</span> │ └─────────────────────────────────┴────────────────────────┴───────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">28,241</span> (110.32 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">28,241</span> (110.32 KB) 
</pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <p>As you can see, the model can be nested: a model can contain sub-models (since a model is just like a layer). A common use case for model nesting is <em>ensembling</em>. For example, here's how to ensemble a set of models into a single model that averages their predictions:</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">get_model</span><span class="p">():</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">128</span><span class="p">,))</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> <span class="n">model1</span> <span class="o">=</span> <span class="n">get_model</span><span class="p">()</span> <span class="n">model2</span> <span class="o">=</span> <span class="n">get_model</span><span class="p">()</span> <span class="n">model3</span> <span class="o">=</span> <span class="n">get_model</span><span class="p">()</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span 
class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">128</span><span class="p">,))</span> <span class="n">y1</span> <span class="o">=</span> <span class="n">model1</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">y2</span> <span class="o">=</span> <span class="n">model2</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">y3</span> <span class="o">=</span> <span class="n">model3</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">average</span><span class="p">([</span><span class="n">y1</span><span class="p">,</span> <span class="n">y2</span><span class="p">,</span> <span class="n">y3</span><span class="p">])</span> <span class="n">ensemble_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">outputs</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="manipulate-complex-graph-topologies">Manipulate complex graph topologies</h2> <h3 id="models-with-multiple-inputs-and-outputs">Models with multiple inputs and outputs</h3> <p>The functional API makes it easy to manipulate multiple inputs and outputs. 
This cannot be handled with the <code>Sequential</code> API.</p> <p>For example, if you're building a system for ranking customer issue tickets by priority and routing them to the correct department, then the model will have three inputs:</p> <ul> <li>the title of the ticket (text input),</li> <li>the text body of the ticket (text input), and</li> <li>any tags added by the user (categorical input)</li> </ul> <p>This model will have two outputs:</p> <ul> <li>the priority score between 0 and 1 (scalar sigmoid output), and</li> <li>the department that should handle the ticket (softmax output over the set of departments).</li> </ul> <p>You can build this model in a few lines with the functional API:</p> <div class="codehilite"><pre><span></span><code><span class="n">num_tags</span> <span class="o">=</span> <span class="mi">12</span> <span class="c1"># Number of unique issue tags</span> <span class="n">num_words</span> <span class="o">=</span> <span class="mi">10000</span> <span class="c1"># Size of vocabulary obtained when preprocessing text data</span> <span class="n">num_departments</span> <span class="o">=</span> <span class="mi">4</span> <span class="c1"># Number of departments for predictions</span> <span class="n">title_input</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"title"</span> <span class="p">)</span> <span class="c1"># Variable-length sequence of ints</span> <span class="n">body_input</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,),</span> <span 
class="n">name</span><span class="o">=</span><span class="s2">"body"</span><span class="p">)</span> <span class="c1"># Variable-length sequence of ints</span> <span class="n">tags_input</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">num_tags</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"tags"</span> <span class="p">)</span> <span class="c1"># Binary vectors of size `num_tags`</span> <span class="c1"># Embed each word in the title into a 64-dimensional vector</span> <span class="n">title_features</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">num_words</span><span class="p">,</span> <span class="mi">64</span><span class="p">)(</span><span class="n">title_input</span><span class="p">)</span> <span class="c1"># Embed each word in the body into a 64-dimensional vector</span> <span class="n">body_features</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">num_words</span><span class="p">,</span> <span class="mi">64</span><span class="p">)(</span><span class="n">body_input</span><span class="p">)</span> <span class="c1"># Reduce sequence of embedded words in the title into a single 128-dimensional vector</span> <span class="n">title_features</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LSTM</span><span class="p">(</span><span class="mi">128</span><span class="p">)(</span><span class="n">title_features</span><span class="p">)</span> <span class="c1"># Reduce sequence of embedded words in the body into a single 32-dimensional vector</span> <span 
class="n">body_features</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LSTM</span><span class="p">(</span><span class="mi">32</span><span class="p">)(</span><span class="n">body_features</span><span class="p">)</span> <span class="c1"># Merge all available features into a single large vector via concatenation</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">title_features</span><span class="p">,</span> <span class="n">body_features</span><span class="p">,</span> <span class="n">tags_input</span><span class="p">])</span> <span class="c1"># Stick a logistic regression for priority prediction on top of the features</span> <span class="n">priority_pred</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"priority"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Stick a department classifier on top of the features</span> <span class="n">department_pred</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">num_departments</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"department"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Instantiate an end-to-end model predicting both priority and department</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span> <span class="n">inputs</span><span class="o">=</span><span 
class="p">[</span><span class="n">title_input</span><span class="p">,</span> <span class="n">body_input</span><span class="p">,</span> <span class="n">tags_input</span><span class="p">],</span> <span class="n">outputs</span><span class="o">=</span><span class="p">{</span><span class="s2">"priority"</span><span class="p">:</span> <span class="n">priority_pred</span><span class="p">,</span> <span class="s2">"department"</span><span class="p">:</span> <span class="n">department_pred</span><span class="p">},</span> <span class="p">)</span> </code></pre></div> <p>Now plot the model:</p> <div class="codehilite"><pre><span></span><code><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">plot_model</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="s2">"multi_input_and_output_model.png"</span><span class="p">,</span> <span class="n">show_shapes</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/guides/functional_api/functional_api_40_0.png" /></p> <p>When compiling this model, you can assign different losses to each output. 
You can also assign a weight to each loss to control its contribution to the total training loss, which is the weighted sum of the individual losses.</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">RMSprop</span><span class="p">(</span><span class="mf">1e-3</span><span class="p">),</span> <span class="n">loss</span><span class="o">=</span><span class="p">[</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">BinaryCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">CategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span> <span class="p">],</span> <span class="n">loss_weights</span><span class="o">=</span><span class="p">[</span><span class="mf">1.0</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">],</span> <span class="p">)</span> </code></pre></div> <p>Since the output layers have different names, you can also specify the losses and loss weights by layer name:</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">RMSprop</span><span class="p">(</span><span class="mf">1e-3</span><span 
class="p">),</span> <span class="n">loss</span><span class="o">=</span><span class="p">{</span> <span class="s2">"priority"</span><span class="p">:</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">BinaryCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span> <span class="s2">"department"</span><span class="p">:</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">CategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span> <span class="p">},</span> <span class="n">loss_weights</span><span class="o">=</span><span class="p">{</span><span class="s2">"priority"</span><span class="p">:</span> <span class="mf">1.0</span><span class="p">,</span> <span class="s2">"department"</span><span class="p">:</span> <span class="mf">0.2</span><span class="p">},</span> <span class="p">)</span> </code></pre></div> <p>Train the model by passing dictionaries of NumPy arrays, keyed by input and output name, as inputs and targets:</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Dummy input data</span> <span class="n">title_data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="n">num_words</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">1280</span><span class="p">,</span> <span class="mi">12</span><span class="p">))</span> <span class="n">body_data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span 
class="p">(</span><span class="n">num_words</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">1280</span><span class="p">,</span> <span class="mi">100</span><span class="p">))</span> <span class="n">tags_data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">1280</span><span class="p">,</span> <span class="n">num_tags</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">)</span> <span class="c1"># Dummy target data</span> <span class="n">priority_targets</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">1280</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="n">dept_targets</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">1280</span><span class="p">,</span> <span class="n">num_departments</span><span class="p">))</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="p">{</span><span class="s2">"title"</span><span class="p">:</span> <span class="n">title_data</span><span class="p">,</span> <span 
class="s2">"body"</span><span class="p">:</span> <span class="n">body_data</span><span class="p">,</span> <span class="s2">"tags"</span><span class="p">:</span> <span class="n">tags_data</span><span class="p">},</span> <span class="p">{</span><span class="s2">"priority"</span><span class="p">:</span> <span class="n">priority_targets</span><span class="p">,</span> <span class="s2">"department"</span><span class="p">:</span> <span class="n">dept_targets</span><span class="p">},</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/2 40/40 ━━━━━━━━━━━━━━━━━━━━ 3s 57ms/step - loss: 1108.3792 Epoch 2/2 40/40 ━━━━━━━━━━━━━━━━━━━━ 2s 54ms/step - loss: 621.3049 &lt;keras.src.callbacks.history.History at 0x34afc3d90&gt; </code></pre></div> </div> <p>When calling <code>fit</code> with a <code>Dataset</code> object, the dataset should yield either a tuple of lists like <code>([title_data, body_data, tags_data], [priority_targets, dept_targets])</code> or a tuple of dictionaries like <code>({'title': title_data, 'body': body_data, 'tags': tags_data}, {'priority': priority_targets, 'department': dept_targets})</code>.</p> <p>For a more detailed explanation, refer to the <a href="/guides/training_with_built_in_methods/">training and evaluation</a> guide.</p> <h3 id="a-toy-resnet-model">A toy ResNet model</h3> <p>In addition to models with multiple inputs and outputs, the functional API makes it easy to manipulate non-linear connectivity topologies: models whose layers are not connected sequentially, which the <code>Sequential</code> API cannot handle.</p> <p>A common use case for this is residual connections. 
Let's build a toy ResNet model for CIFAR10 to demonstrate this:</p> <div class="codehilite"><pre><span></span><code><span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"img"</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">block_1_output</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MaxPooling2D</span><span class="p">(</span><span class="mi">3</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span 
class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">)(</span><span class="n">block_1_output</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">block_2_output</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">add</span><span class="p">([</span><span class="n">x</span><span class="p">,</span> <span class="n">block_1_output</span><span class="p">])</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">)(</span><span class="n">block_2_output</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">,</span> <span 
class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">block_3_output</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">add</span><span class="p">([</span><span class="n">x</span><span class="p">,</span> <span class="n">block_2_output</span><span class="p">])</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">block_3_output</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">GlobalAveragePooling2D</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span 
class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"toy_resnet"</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> </code></pre></div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "toy_resnet"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃<span style="font-weight: bold"> Connected to </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ │ img (<span style="color: #0087ff; text-decoration-color: #0087ff">InputLayer</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ - │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2d_8 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">30</span>, <span style="color: 
#00af00; text-decoration-color: #00af00">30</span>, │ <span style="color: #00af00; text-decoration-color: #00af00">896</span> │ img[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>] │ │ │ <span style="color: #00af00; text-decoration-color: #00af00">32</span>) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2d_9 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">28</span>, <span style="color: #00af00; text-decoration-color: #00af00">28</span>, │ <span style="color: #00af00; text-decoration-color: #00af00">18,496</span> │ conv2d_8[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>] │ │ │ <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ max_pooling2d_2 │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">9</span>, <span style="color: #00af00; text-decoration-color: #00af00">9</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ conv2d_9[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>] │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">MaxPooling2D</span>) │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2d_10 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: 
#00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">9</span>, <span style="color: #00af00; text-decoration-color: #00af00">9</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">36,928</span> │ max_pooling2d_2[<span style="color: #00af00; text-decoration-color: #00af00">…</span> │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2d_11 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">9</span>, <span style="color: #00af00; text-decoration-color: #00af00">9</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">36,928</span> │ conv2d_10[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add (<span style="color: #0087ff; text-decoration-color: #0087ff">Add</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">9</span>, <span style="color: #00af00; text-decoration-color: #00af00">9</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ conv2d_11[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>], │ │ │ │ │ max_pooling2d_2[<span style="color: #00af00; text-decoration-color: #00af00">…</span> │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2d_12 (<span style="color: #0087ff; 
text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">9</span>, <span style="color: #00af00; text-decoration-color: #00af00">9</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">36,928</span> │ add[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2d_13 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">9</span>, <span style="color: #00af00; text-decoration-color: #00af00">9</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">36,928</span> │ conv2d_12[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ add_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">Add</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">9</span>, <span style="color: #00af00; text-decoration-color: #00af00">9</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ conv2d_13[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>], │ │ │ │ │ add[<span style="color: #00af00; text-decoration-color: 
#00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ conv2d_14 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">7</span>, <span style="color: #00af00; text-decoration-color: #00af00">7</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">36,928</span> │ add_1[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ global_average_poo… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ conv2d_14[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>] │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">GlobalAveragePool…</span> │ │ │ │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_6 (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">256</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">16,640</span> │ global_average_p… │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dropout (<span style="color: #0087ff; text-decoration-color: #0087ff">Dropout</span>) │ (<span style="color: #00d7ff; text-decoration-color: 
#00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">256</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ dense_6[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>] │ ├─────────────────────┼───────────────────┼────────────┼───────────────────┤ │ dense_7 (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">10</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">2,570</span> │ dropout[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>] │ └─────────────────────┴───────────────────┴────────────┴───────────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">223,242</span> (872.04 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">223,242</span> (872.04 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <p>Plot the model:</p> <div class="codehilite"><pre><span></span><code><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span 
class="n">plot_model</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="s2">"mini_resnet.png"</span><span class="p">,</span> <span class="n">show_shapes</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/guides/functional_api/functional_api_51_0.png" /></p> <p>Now train the model:</p> <div class="codehilite"><pre><span></span><code><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">cifar10</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span> <span class="n">x_train</span> <span class="o">=</span> <span class="n">x_train</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">)</span> <span class="o">/</span> <span class="mf">255.0</span> <span class="n">x_test</span> <span class="o">=</span> <span class="n">x_test</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">)</span> <span class="o">/</span> <span class="mf">255.0</span> <span class="n">y_train</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">to_categorical</span><span class="p">(</span><span class="n">y_train</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span> <span class="n">y_test</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span 
class="o">.</span><span class="n">to_categorical</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">RMSprop</span><span class="p">(</span><span class="mf">1e-3</span><span class="p">),</span> <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">CategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"acc"</span><span class="p">],</span> <span class="p">)</span> <span class="c1"># We restrict the data to the first 1000 samples so as to limit execution time</span> <span class="c1"># on Colab. 
Try to train on the entire dataset until convergence!</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">x_train</span><span class="p">[:</span><span class="mi">1000</span><span class="p">],</span> <span class="n">y_train</span><span class="p">[:</span><span class="mi">1000</span><span class="p">],</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">validation_split</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 13/13 ━━━━━━━━━━━━━━━━━━━━ 1s 60ms/step - acc: 0.1096 - loss: 2.3053 - val_acc: 0.1150 - val_loss: 2.2973 &lt;keras.src.callbacks.history.History at 0x1758bed40&gt; </code></pre></div> </div> <hr /> <h2 id="shared-layers">Shared layers</h2> <p>Another good use for the functional API is building models that use <em>shared layers</em>. Shared layers are layer instances that are reused multiple times in the same model – they learn features that correspond to multiple paths in the graph-of-layers.</p> <p>Shared layers are often used to encode inputs from similar spaces (say, two different pieces of text that feature similar vocabulary). They enable sharing of information across these different inputs, and they make it possible to train such a model on less data. If a given word is seen in one of the inputs, that will benefit the processing of all inputs that pass through the shared layer.</p> <p>To share a layer in the functional API, call the same layer instance multiple times. 
For instance, here's an <code>Embedding</code> layer shared across two different text inputs:</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Embedding for 1000 unique words mapped to 128-dimensional vectors</span> <span class="n">shared_embedding</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="mi">1000</span><span class="p">,</span> <span class="mi">128</span><span class="p">)</span> <span class="c1"># Variable-length sequence of integers</span> <span class="n">text_input_a</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int32"</span><span class="p">)</span> <span class="c1"># Variable-length sequence of integers</span> <span class="n">text_input_b</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int32"</span><span class="p">)</span> <span class="c1"># Reuse the same layer to encode both inputs</span> <span class="n">encoded_input_a</span> <span class="o">=</span> <span class="n">shared_embedding</span><span class="p">(</span><span class="n">text_input_a</span><span class="p">)</span> <span class="n">encoded_input_b</span> <span class="o">=</span> <span class="n">shared_embedding</span><span class="p">(</span><span class="n">text_input_b</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="extract-and-reuse-nodes-in-the-graph-of-layers">Extract and reuse 
nodes in the graph of layers</h2> <p>Because the graph of layers you are manipulating is a static data structure, it can be accessed and inspected. And this is how you are able to plot functional models as images.</p> <p>This also means that you can access the activations of intermediate layers ("nodes" in the graph) and reuse them elsewhere – which is very useful for something like feature extraction.</p> <p>Let's look at an example. This is a VGG19 model with weights pretrained on ImageNet:</p> <div class="codehilite"><pre><span></span><code><span class="n">vgg19</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">applications</span><span class="o">.</span><span class="n">VGG19</span><span class="p">()</span> </code></pre></div> <p>And these are the intermediate activations of the model, obtained by querying the graph data structure:</p> <div class="codehilite"><pre><span></span><code><span class="n">features_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">layer</span><span class="o">.</span><span class="n">output</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">vgg19</span><span class="o">.</span><span class="n">layers</span><span class="p">]</span> </code></pre></div> <p>Use these features to create a new feature-extraction model that returns the values of the intermediate layer activations:</p> <div class="codehilite"><pre><span></span><code><span class="n">feat_extraction_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">vgg19</span><span class="o">.</span><span class="n">input</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">features_list</span><span class="p">)</span> <span class="n">img</span> <span 
class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">)</span> <span class="n">extracted_features</span> <span class="o">=</span> <span class="n">feat_extraction_model</span><span class="p">(</span><span class="n">img</span><span class="p">)</span> </code></pre></div> <p>This comes in handy for tasks like <a href="https://keras.io/examples/generative/neural_style_transfer/">neural style transfer</a>, among other things.</p> <hr /> <h2 id="extend-the-api-using-custom-layers">Extend the API using custom layers</h2> <p><code>keras</code> includes a wide range of built-in layers, for example:</p> <ul> <li>Convolutional layers: <code>Conv1D</code>, <code>Conv2D</code>, <code>Conv3D</code>, <code>Conv2DTranspose</code></li> <li>Pooling layers: <code>MaxPooling1D</code>, <code>MaxPooling2D</code>, <code>MaxPooling3D</code>, <code>AveragePooling1D</code></li> <li>RNN layers: <code>GRU</code>, <code>LSTM</code>, <code>ConvLSTM2D</code></li> <li><code>BatchNormalization</code>, <code>Dropout</code>, <code>Embedding</code>, etc.</li> </ul> <p>But if you don't find what you need, it's easy to extend the API by creating your own layers. 
All layers subclass the <code>Layer</code> class and implement:</p> <ul> <li>A <code>call</code> method, which specifies the computation done by the layer.</li> <li>A <code>build</code> method, which creates the weights of the layer (this is just a style convention, since you can also create weights in <code>__init__</code>).</li> </ul> <p>To learn more about creating layers from scratch, read the <a href="/guides/making_new_layers_and_models_via_subclassing">custom layers and models</a> guide.</p> <p>The following is a basic implementation of <a href="/api/layers/core_layers/dense#dense-class"><code>keras.layers.Dense</code></a>:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">CustomDense</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">units</span><span class="o">=</span><span class="mi">32</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span> <span class="o">=</span> <span class="n">units</span> <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">input_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span 
class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">),</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">"random_normal"</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">,),</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">"random_normal"</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="k">return</span> <span class="n">ops</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">((</span><span class="mi">4</span><span class="p">,))</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">CustomDense</span><span class="p">(</span><span class="mi">10</span><span class="p">)(</span><span class="n">inputs</span><span 
class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> </code></pre></div> <p>For serialization support in your custom layer, define a <code>get_config()</code> method that returns the constructor arguments of the layer instance:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">CustomDense</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">units</span><span class="o">=</span><span class="mi">32</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span> <span class="o">=</span> <span class="n">units</span> <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">input_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">),</span> <span class="n">initializer</span><span 
class="o">=</span><span class="s2">"random_normal"</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">,),</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">"random_normal"</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="k">return</span> <span class="n">ops</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span> <span class="k">def</span> <span class="nf">get_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="k">return</span> <span class="p">{</span><span class="s2">"units"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">}</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">((</span><span class="mi">4</span><span class="p">,))</span> <span 
class="n">outputs</span> <span class="o">=</span> <span class="n">CustomDense</span><span class="p">(</span><span class="mi">10</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> <span class="n">config</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">get_config</span><span class="p">()</span> <span class="n">new_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="o">.</span><span class="n">from_config</span><span class="p">(</span><span class="n">config</span><span class="p">,</span> <span class="n">custom_objects</span><span class="o">=</span><span class="p">{</span><span class="s2">"CustomDense"</span><span class="p">:</span> <span class="n">CustomDense</span><span class="p">})</span> </code></pre></div> <p>Optionally, implement the class method <code>from_config(cls, config)</code> which is used when recreating a layer instance given its config dictionary. The default implementation of <code>from_config</code> is:</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">from_config</span><span class="p">(</span><span class="bp">cls</span><span class="p">,</span> <span class="n">config</span><span class="p">):</span> <span class="k">return</span> <span class="bp">cls</span><span class="p">(</span><span class="o">**</span><span class="n">config</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="when-to-use-the-functional-api">When to use the functional API</h2> <p>Should you use the Keras functional API to create a new model, or just subclass the <code>Model</code> class directly? 
In general, the functional API is higher-level, easier and safer, and has a number of features that subclassed models do not support.</p> <p>However, model subclassing provides greater flexibility when building models that are not easily expressible as directed acyclic graphs of layers. For example, you could not implement a Tree-RNN with the functional API and would have to subclass <code>Model</code> directly.</p> <p>For an in-depth look at the differences between the functional API and model subclassing, read <a href="https://blog.tensorflow.org/2019/01/what-are-symbolic-and-imperative-apis.html">What are Symbolic and Imperative APIs in TensorFlow 2.0?</a>.</p> <h3 id="functional-api-strengths">Functional API strengths:</h3> <p>The following properties are also true for Sequential models (which are also data structures), but are not true for subclassed models (which are Python bytecode, not data structures).</p> <h4 id="less-verbose">Less verbose</h4> <p>There is no <code>super().__init__(...)</code>, no <code>def call(self, ...):</code>, etc.</p> <p>Compare:</p> <div class="codehilite"><pre><span></span><code><span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">32</span><span class="p">,))</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">'relu'</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span 
class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">mlp</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> </code></pre></div> <p>With the subclassed version:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">MLP</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">'relu'</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span 
class="n">dense_1</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Instantiate the model.</span> <span class="n">mlp</span> <span class="o">=</span> <span class="n">MLP</span><span class="p">()</span> <span class="c1"># Necessary to create the model's state.</span> <span class="c1"># The model doesn't have a state until it's called at least once.</span> <span class="n">_</span> <span class="o">=</span> <span class="n">mlp</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">32</span><span class="p">)))</span> </code></pre></div> <h4 id="model-validation-while-defining-its-connectivity-graph">Model validation while defining its connectivity graph</h4> <p>In the functional API, the input specification (shape and dtype) is created in advance (using <code>Input</code>). Every time you call a layer, the layer checks that the specification passed to it matches its assumptions, and it will raise a helpful error message if not.</p> <p>This guarantees that any model you can build with the functional API will run. All debugging – other than convergence-related debugging – happens statically during the model construction and not at execution time. This is similar to type checking in a compiler.</p> <h4 id="a-functional-model-is-plottable-and-inspectable">A functional model is plottable and inspectable</h4> <p>You can plot the model as a graph, and you can easily access intermediate nodes in this graph. 
For example, to extract and reuse the activations of intermediate layers (as seen in a previous example):</p> <div class="codehilite"><pre><span></span><code><span class="n">features_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">layer</span><span class="o">.</span><span class="n">output</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">vgg19</span><span class="o">.</span><span class="n">layers</span><span class="p">]</span> <span class="n">feat_extraction_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">vgg19</span><span class="o">.</span><span class="n">input</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">features_list</span><span class="p">)</span> </code></pre></div> <h4 id="a-functional-model-can-be-serialized-or-cloned">A functional model can be serialized or cloned</h4> <p>Because a functional model is a data structure rather than a piece of code, it is safely serializable and can be saved as a single file that allows you to recreate the exact same model without having access to any of the original code. See the <a href="/guides/serialization_and_saving/">serialization &amp; saving guide</a>.</p> <p>To serialize a subclassed model, the implementer must define <code>get_config()</code> and <code>from_config()</code> methods at the model level.</p> <h3 id="functional-api-weakness">Functional API weakness:</h3> <h4 id="it-does-not-support-dynamic-architectures">It does not support dynamic architectures</h4> <p>The functional API treats models as DAGs of layers.
This is true for most deep learning architectures, but not all – for example, recursive networks or Tree RNNs do not follow this assumption and cannot be implemented in the functional API.</p> <hr /> <h2 id="mixandmatch-api-styles">Mix-and-match API styles</h2> <p>Choosing between the functional API and Model subclassing isn't a binary decision that restricts you to one category of models. All models in the <code>keras</code> API can interact with each other, whether they're <code>Sequential</code> models, functional models, or subclassed models that are written from scratch.</p> <p>You can always use a functional model or <code>Sequential</code> model as part of a subclassed model or layer:</p> <div class="codehilite"><pre><span></span><code><span class="n">units</span> <span class="o">=</span> <span class="mi">32</span> <span class="n">timesteps</span> <span class="o">=</span> <span class="mi">10</span> <span class="n">input_dim</span> <span class="o">=</span> <span class="mi">5</span> <span class="c1"># Define a Functional model</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">((</span><span class="kc">None</span><span class="p">,</span> <span class="n">units</span><span class="p">))</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">GlobalAveragePooling1D</span><span class="p">()(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span
class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> <span class="k">class</span> <span class="nc">CustomRNN</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span> <span class="o">=</span> <span class="n">units</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection_1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="n">units</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"tanh"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection_2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="n">units</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"tanh"</span><span class="p">)</span> <span class="c1"># Our previously-defined Functional model</span> <span class="bp">self</span><span class="o">.</span><span class="n">classifier</span> <span class="o">=</span> <span class="n">model</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">outputs</span> 
<span class="o">=</span> <span class="p">[]</span> <span class="n">state</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">inputs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">))</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">inputs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[:,</span> <span class="n">t</span><span class="p">,</span> <span class="p">:]</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection_1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">y</span> <span class="o">=</span> <span class="n">h</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection_2</span><span class="p">(</span><span class="n">state</span><span class="p">)</span> <span class="n">state</span> <span class="o">=</span> <span class="n">y</span> <span class="n">outputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">y</span><span class="p">)</span> <span class="n">features</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">outputs</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span 
class="mi">1</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">features</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">classifier</span><span class="p">(</span><span class="n">features</span><span class="p">)</span> <span class="n">rnn_model</span> <span class="o">=</span> <span class="n">CustomRNN</span><span class="p">()</span> <span class="n">_</span> <span class="o">=</span> <span class="n">rnn_model</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="n">timesteps</span><span class="p">,</span> <span class="n">input_dim</span><span class="p">)))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>(1, 10, 32) (1, 10, 32) </code></pre></div> </div> <p>You can use any subclassed layer or model in the functional API as long as it implements a <code>call</code> method that follows one of the following patterns:</p> <ul> <li><code>call(self, inputs, **kwargs)</code> – Where <code>inputs</code> is a tensor or a nested structure of tensors (e.g. 
a list of tensors), and where <code>**kwargs</code> are non-tensor arguments (non-inputs).</li> <li><code>call(self, inputs, training=None, **kwargs)</code> – Where <code>training</code> is a boolean indicating whether the layer should behave in training mode or inference mode.</li> <li><code>call(self, inputs, mask=None, **kwargs)</code> – Where <code>mask</code> is a boolean mask tensor (useful for RNNs, for instance).</li> <li><code>call(self, inputs, training=None, mask=None, **kwargs)</code> – Of course, you can have both masking and training-specific behavior at the same time.</li> </ul> <p>Additionally, if you implement the <code>get_config</code> method on your custom Layer or model, the functional models you create will still be serializable and cloneable.</p> <p>Here's a quick example of a custom RNN, written from scratch, being used in a functional model:</p> <div class="codehilite"><pre><span></span><code><span class="n">units</span> <span class="o">=</span> <span class="mi">32</span> <span class="n">timesteps</span> <span class="o">=</span> <span class="mi">10</span> <span class="n">input_dim</span> <span class="o">=</span> <span class="mi">5</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">16</span> <span class="k">class</span> <span class="nc">CustomRNN</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span> <span class="o">=</span> <span class="n">units</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection_1</span> <span class="o">=</span> <span
class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="n">units</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"tanh"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection_2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="n">units</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"tanh"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">classifier</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">outputs</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">state</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">inputs</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">))</span> <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">inputs</span><span class="o">.</span><span class="n">shape</span><span 
class="p">[</span><span class="mi">1</span><span class="p">]):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[:,</span> <span class="n">t</span><span class="p">,</span> <span class="p">:]</span> <span class="n">h</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection_1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">y</span> <span class="o">=</span> <span class="n">h</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection_2</span><span class="p">(</span><span class="n">state</span><span class="p">)</span> <span class="n">state</span> <span class="o">=</span> <span class="n">y</span> <span class="n">outputs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">y</span><span class="p">)</span> <span class="n">features</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">outputs</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">classifier</span><span class="p">(</span><span class="n">features</span><span class="p">)</span> <span class="c1"># Note that you specify a static batch size for the inputs with the `batch_shape`</span> <span class="c1"># arg, because the inner computation of `CustomRNN` requires a static batch size</span> <span class="c1"># (when you create the `state` zeros tensor).</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">batch_shape</span><span class="o">=</span><span class="p">(</span><span 
class="n">batch_size</span><span class="p">,</span> <span class="n">timesteps</span><span class="p">,</span> <span class="n">input_dim</span><span class="p">))</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv1D</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">CustomRNN</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> <span class="n">rnn_model</span> <span class="o">=</span> <span class="n">CustomRNN</span><span class="p">()</span> <span class="n">_</span> <span class="o">=</span> <span class="n">rnn_model</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">5</span><span class="p">)))</span> </code></pre></div> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#the-functional-api'>The Functional API</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#training-evaluation-and-inference'>Training, evaluation, and inference</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#save-and-serialize'>Save and serialize</a> </div> <div class='k-outline-depth-2'> ◆ <a 
href='#use-the-same-graph-of-layers-to-define-multiple-models'>Use the same graph of layers to define multiple models</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#all-models-are-callable-just-like-layers'>All models are callable, just like layers</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#manipulate-complex-graph-topologies'>Manipulate complex graph topologies</a> </div> <div class='k-outline-depth-3'> <a href='#models-with-multiple-inputs-and-outputs'>Models with multiple inputs and outputs</a> </div> <div class='k-outline-depth-3'> <a href='#a-toy-resnet-model'>A toy ResNet model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#shared-layers'>Shared layers</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#extract-and-reuse-nodes-in-the-graph-of-layers'>Extract and reuse nodes in the graph of layers</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#extend-the-api-using-custom-layers'>Extend the API using custom layers</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#when-to-use-the-functional-api'>When to use the functional API</a> </div> <div class='k-outline-depth-3'> <a href='#functional-api-strengths'>Functional API strengths:</a> </div> <div class='k-outline-depth-3'> <a href='#functional-api-weakness'>Functional API weakness:</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#mixandmatch-api-styles'>Mix-and-match API styles</a> </div> </div> </div> </div> </div> </body> </html>