<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/guides/migrating_to_keras_3/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Migrating Keras 2 code to multi-backend Keras 3"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Migrating Keras 2 code to multi-backend Keras 3"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Migrating Keras 2 code to multi-backend Keras 3</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); 
ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link active" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-sublink" href="/guides/functional_api/">The Functional API</a> <a class="nav-sublink" href="/guides/sequential_model/">The Sequential model</a> <a class="nav-sublink" href="/guides/making_new_layers_and_models_via_subclassing/">Making new layers & models via subclassing</a> <a class="nav-sublink" href="/guides/training_with_built_in_methods/">Training & evaluation with the built-in methods</a> <a class="nav-sublink" href="/guides/custom_train_step_in_jax/">Customizing `fit()` with JAX</a> <a class="nav-sublink" href="/guides/custom_train_step_in_tensorflow/">Customizing `fit()` with TensorFlow</a> <a class="nav-sublink" href="/guides/custom_train_step_in_torch/">Customizing `fit()` with PyTorch</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_jax/">Writing a custom training loop in JAX</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_tensorflow/">Writing a custom training loop in TensorFlow</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_torch/">Writing a custom training loop in PyTorch</a> <a class="nav-sublink" 
href="/guides/serialization_and_saving/">Serialization & saving</a> <a class="nav-sublink" href="/guides/customizing_saving_and_serialization/">Customizing saving & serialization</a> <a class="nav-sublink" href="/guides/writing_your_own_callbacks/">Writing your own callbacks</a> <a class="nav-sublink" href="/guides/transfer_learning/">Transfer learning & fine-tuning</a> <a class="nav-sublink" href="/guides/distributed_training_with_jax/">Distributed training with JAX</a> <a class="nav-sublink" href="/guides/distributed_training_with_tensorflow/">Distributed training with TensorFlow</a> <a class="nav-sublink" href="/guides/distributed_training_with_torch/">Distributed training with PyTorch</a> <a class="nav-sublink" href="/guides/distribution/">Distributed training with Keras 3</a> <a class="nav-sublink active" href="/guides/migrating_to_keras_3/">Migrating Keras 2 code to Keras 3</a> <a class="nav-sublink" href="/guides/keras_tuner/">Hyperparameter Tuning</a> <a class="nav-sublink" href="/guides/keras_cv/">KerasCV</a> <a class="nav-sublink" href="/guides/keras_nlp/">KerasNLP</a> <a class="nav-sublink" href="/guides/keras_hub/">KerasHub</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if 
(e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." 
autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/guides/'>Developer guides</a> / Migrating Keras 2 code to multi-backend Keras 3 </div> <div class='k-content'> <h1 id="migrating-keras-2-code-to-multibackend-keras-3">Migrating Keras 2 code to multi-backend Keras 3</h1> <p><strong>Author:</strong> <a href="https://github.com/divyashreepathihalli">Divyashree Sreepathihalli</a><br> <strong>Date created:</strong> 2023/10/23<br> <strong>Last modified:</strong> 2023/10/30<br> <strong>Description:</strong> Instructions & troubleshooting for migrating your Keras 2 code to multi-backend Keras 3.</p> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a
href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/migrating_to_keras_3.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/guides/migrating_to_keras_3.py"><strong>GitHub source</strong></a></p> <p>This guide will help you migrate TensorFlow-only Keras 2 code to multi-backend Keras 3 code. The overhead of the migration is minimal. Once you have migrated, you can run Keras workflows on top of JAX, TensorFlow, or PyTorch.</p> <p>This guide has two parts:</p> <ol> <li>Migrating your legacy Keras 2 code to Keras 3, running on top of the TensorFlow backend. This is generally very easy, though there are minor issues to be mindful of, which we will go over in detail.</li> <li>Further migrating your Keras 3 + TensorFlow code to multi-backend Keras 3, so that it can run on JAX and PyTorch.</li> </ol> <p>Let's get started.</p> <hr /> <h2 id="setup">Setup</h2> <p>First, let's install <code>keras-nightly</code>.</p> <p>This example uses the TensorFlow backend (<code>os.environ["KERAS_BACKEND"] = "tensorflow"</code>).
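</p>

<p>You can confirm at runtime which backend ended up active; a minimal sketch (note that <code>KERAS_BACKEND</code> is read only once, when <code>keras</code> is first imported, so the environment variable must be set before the import):</p>

```python
import os

# The environment variable is only read once, at import time.
os.environ["KERAS_BACKEND"] = "tensorflow"

import keras

# keras.backend.backend() returns the name of the currently active backend.
print(keras.backend.backend())
```

<p>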
After you've migrated your code, you can change the <code>"tensorflow"</code> string to <code>"jax"</code> or <code>"torch"</code> and click "Restart runtime" in Colab, and your code will run on the JAX or PyTorch backend.</p> <div class="codehilite"><pre><span></span><code><span class="err">!</span><span class="n">pip</span> <span class="n">install</span> <span class="o">-</span><span class="n">q</span> <span class="n">keras</span><span class="o">-</span><span class="n">nightly</span> </code></pre></div> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"tensorflow"</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> [notice] A new release of pip is available: 23.3.1 -> 24.0 [notice] To update, run: pip install --upgrade pip </code></pre></div> </div> <hr /> <h2 id="going-from-keras-2-to-keras-3-with-the-tensorflow-backend">Going from Keras 2 to Keras 3 with the TensorFlow backend</h2> <p>First, replace your imports:</p> <ol> <li>Replace <code>from tensorflow import keras</code> with <code>import keras</code></li> <li>Replace <code>from tensorflow.keras import xyz</code> (e.g. <code>from tensorflow.keras import layers</code>) with <code>from keras import xyz</code> (e.g.
<code>from keras import layers</code>)</li> <li>Replace <a href="https://www.tensorflow.org/api_docs/python/tf/keras/*"><code>tf.keras.*</code></a> with <code>keras.*</code></li> </ol> <p>Next, start running your tests. Most of the time, your code will execute on Keras 3 just fine. Any issues you might encounter are detailed below, along with their fixes.</p> <h3 id="jitcompile-is-set-to-true-by-default-on-gpu"><code>jit_compile</code> is set to <code>True</code> by default on GPU.</h3> <p>The default value of the <code>jit_compile</code> argument to the <code>Model</code> constructor has been set to <code>True</code> on GPU in Keras 3. This means that models will be compiled with Just-In-Time (JIT) compilation by default on GPU.</p> <p>JIT compilation can improve the performance of some models. However, it may not work with all TensorFlow operations. If you are using a custom model or layer and you see an XLA-related error, you may need to set the <code>jit_compile</code> argument to <code>False</code>. Here is a list of <a href="https://www.tensorflow.org/xla/known_issues">known issues</a> encountered when using XLA with TensorFlow.
In addition to these issues, there are some ops that are not supported by XLA.</p> <p>The error message you could encounter would be as follows:</p> <div class="codehilite"><pre><span></span><code>Detected unsupported operations when trying to compile graph __inference_one_step_on_data_125[] on XLA_GPU_JIT </code></pre></div> <p>For example, the following snippet of code will reproduce the above error:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">MyModel</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">string_input</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">strings</span><span class="o">.</span><span class="n">as_string</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="k">return</span> <span class="n">tf</span><span class="o">.</span><span class="n">strings</span><span class="o">.</span><span class="n">to_number</span><span class="p">(</span><span class="n">string_input</span><span class="p">)</span> <span class="n">subclass_model</span> <span 
class="o">=</span> <span class="n">MyModel</span><span class="p">()</span> <span class="n">x_train</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">],</span> <span class="p">[</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">]])</span> <span class="n">subclass_model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="s2">"sgd"</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"mse"</span><span class="p">)</span> <span class="n">subclass_model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">x_train</span><span class="p">)</span> </code></pre></div> <p><strong>How to fix it:</strong> set <code>jit_compile=False</code> in <code>model.compile(..., jit_compile=False)</code>, or set the <code>jit_compile</code> attribute to <code>False</code>, like this:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">MyModel</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span 
class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="c1"># tf.strings ops aren't supported by XLA</span> <span class="n">string_input</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">strings</span><span class="o">.</span><span class="n">as_string</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="k">return</span> <span class="n">tf</span><span class="o">.</span><span class="n">strings</span><span class="o">.</span><span class="n">to_number</span><span class="p">(</span><span class="n">string_input</span><span class="p">)</span> <span class="n">subclass_model</span> <span class="o">=</span> <span class="n">MyModel</span><span class="p">()</span> <span class="n">x_train</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">],</span> <span class="p">[</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">]])</span> <span class="n">subclass_model</span><span class="o">.</span><span class="n">jit_compile</span> <span class="o">=</span> <span class="kc">False</span> <span class="n">subclass_model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">x_train</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 51ms/step array([[1., 2., 3.], [4., 5.,
6.]], dtype=float32) </code></pre></div> </div> <h3 id="saving-a-model-in-the-tf-savedmodel-format">Saving a model in the TF SavedModel format</h3> <p>Saving to the TF SavedModel format via <code>model.save()</code> is no longer supported in Keras 3.</p> <p>The error message you could encounter would be as follows:</p> <div class="codehilite"><pre><span></span><code>>>> model.save("mymodel") ValueError: Invalid filepath extension for saving. Please add either a `.keras` extension for the native Keras format (recommended) or a `.h5` extension. Use `model.export(filepath)` if you want to export a SavedModel for use with TFLite/TFServing/etc. Received: filepath=saved_model. </code></pre></div> <p>The following snippet of code will reproduce the above error:</p> <div class="codehilite"><pre><span></span><code><span class="n">sequential_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span> <span class="p">])</span> <span class="n">sequential_model</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s2">"saved_model"</span><span class="p">)</span> </code></pre></div> <p><strong>How to fix it:</strong> use <code>model.export(filepath)</code> instead of <code>model.save(filepath)</code></p> <div class="codehilite"><pre><span></span><code><span class="n">sequential_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">2</span><span class="p">)])</span> <span 
class="n">sequential_model</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span> <span class="n">sequential_model</span><span class="o">.</span><span class="n">export</span><span class="p">(</span><span class="s2">"saved_model"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>INFO:tensorflow:Assets written to: saved_model/assets Saved artifact at 'saved_model'. The following endpoints are available: </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>* Endpoint 'serve' args_0 (POSITIONAL_ONLY): TensorSpec(shape=(3, 5), dtype=tf.float32, name='keras_tensor') Output Type: TensorSpec(shape=(3, 2), dtype=tf.float32, name=None) Captures: 14428321600: TensorSpec(shape=(), dtype=tf.resource, name=None) 14439128528: TensorSpec(shape=(), dtype=tf.resource, name=None) </code></pre></div> </div> <h3 id="loading-a-tf-savedmodel">Loading a TF SavedModel</h3> <p>Loading a TF SavedModel file via <code>keras.models.load_model()</code> is no longer supported. If you try to use <code>keras.models.load_model()</code> with a TF SavedModel, you will get the following error:</p> <div class="codehilite"><pre><span></span><code><span class="ne">ValueError</span><span class="p">:</span> <span class="n">File</span> <span class="nb">format</span> <span class="ow">not</span> <span class="n">supported</span><span class="p">:</span> <span class="n">filepath</span><span class="o">=</span><span class="n">saved_model</span><span class="o">.</span> <span class="n">Keras</span> <span class="mi">3</span> <span class="n">only</span> <span class="n">supports</span> <span
class="n">V3</span> <span class="err">`</span><span class="o">.</span><span class="n">keras</span><span class="err">`</span> <span class="n">files</span> <span class="ow">and</span> <span class="n">legacy</span> <span class="n">H5</span> <span class="nb">format</span> <span class="n">files</span> <span class="p">(</span><span class="err">`</span><span class="o">.</span><span class="n">h5</span><span class="err">`</span> <span class="n">extension</span><span class="p">)</span><span class="o">.</span> <span class="n">Note</span> <span class="n">that</span> <span class="n">the</span> <span class="n">legacy</span> <span class="n">SavedModel</span> <span class="nb">format</span> <span class="ow">is</span> <span class="ow">not</span> <span class="n">supported</span> <span class="n">by</span> <span class="err">`</span><span class="n">load_model</span><span class="p">()</span><span class="err">`</span> <span class="ow">in</span> <span class="n">Keras</span> <span class="mf">3.</span> <span class="n">In</span> <span class="n">order</span> <span class="n">to</span> <span class="n">reload</span> <span class="n">a</span> <span class="n">TensorFlow</span> <span class="n">SavedModel</span> <span class="k">as</span> <span class="n">an</span> <span class="n">inference</span><span class="o">-</span><span class="n">only</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">Keras</span> <span class="mi">3</span><span class="p">,</span> <span class="n">use</span> <span class="err">`</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">TFSMLayer</span><span class="p">(</span><span class="n">saved_model</span><span class="p">,</span> <span class="n">call_endpoint</span><span class="o">=</span><span class="s1">'serving_default'</span><span class="p">)</span><span class="err">`</span> <span class="p">(</span><span class="n">note</span> <span class="n">that</span> <span class="n">your</span> 
<span class="err">`</span><span class="n">call_endpoint</span><span class="err">`</span> <span class="n">might</span> <span class="n">have</span> <span class="n">a</span> <span class="n">different</span> <span class="n">name</span><span class="p">)</span><span class="o">.</span> </code></pre></div> <p>The following snippet of code will reproduce the above error:</p> <div class="codehilite"><pre><span></span><code><span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">load_model</span><span class="p">(</span><span class="s2">"saved_model"</span><span class="p">)</span> </code></pre></div> <p><strong>How to fix it:</strong> Use <code>keras.layers.TFSMLayer(filepath, call_endpoint="serving_default")</code> to reload a TF SavedModel as a Keras layer. This is not limited to SavedModels that originate from Keras – it will work with any SavedModel, e.g. TF-Hub models.</p> <div class="codehilite"><pre><span></span><code><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">TFSMLayer</span><span class="p">(</span><span class="s2">"saved_model"</span><span class="p">,</span> <span class="n">call_endpoint</span><span class="o">=</span><span class="s2">"serving_default"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code><TFSMLayer name=tfsm_layer, built=True> </code></pre></div> </div> <h3 id="using-deeply-nested-inputs-in-functional-models">Using deeply nested inputs in Functional Models</h3> <p><code>Model()</code> can no longer be passed deeply nested inputs/outputs (nested more than 1 level deep, e.g. lists of lists of tensors).</p> <p>You would encounter errors as follows:</p> <div class="codehilite"><pre><span></span><code>ValueError: When providing `inputs` as a dict, all values in the dict must be KerasTensors. 
Received: inputs={'foo': <KerasTensor shape=(None, 1), dtype=float32, sparse=None, name=foo>, 'bar': {'baz': <KerasTensor shape=(None, 1), dtype=float32, sparse=None, name=bar>}} including invalid value {'baz': <KerasTensor shape=(None, 1), dtype=float32, sparse=None, name=bar>} of type <class 'dict'> </code></pre></div> <p>The following snippet of code will reproduce the above error:</p> <div class="codehilite"><pre><span></span><code><span class="n">inputs</span> <span class="o">=</span> <span class="p">{</span> <span class="s2">"foo"</span><span class="p">:</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"foo"</span><span class="p">),</span> <span class="s2">"bar"</span><span class="p">:</span> <span class="p">{</span> <span class="s2">"baz"</span><span class="p">:</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"bar"</span><span class="p">),</span> <span class="p">},</span> <span class="p">}</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="s2">"foo"</span><span class="p">]</span> <span class="o">+</span> <span class="n">inputs</span><span class="p">[</span><span class="s2">"bar"</span><span class="p">][</span><span class="s2">"baz"</span><span class="p">]</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> 
</code></pre></div> <p><strong>How to fix it:</strong> replace the nested inputs with a flat structure, at most one level deep: a dict, list, or tuple of input tensors.</p> <div class="codehilite"><pre><span></span><code><span class="n">inputs</span> <span class="o">=</span> <span class="p">{</span> <span class="s2">"foo"</span><span class="p">:</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"foo"</span><span class="p">),</span> <span class="s2">"bar"</span><span class="p">:</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"bar"</span><span class="p">),</span> <span class="p">}</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="s2">"foo"</span><span class="p">]</span> <span class="o">+</span> <span class="n">inputs</span><span class="p">[</span><span class="s2">"bar"</span><span class="p">]</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code><Functional name=functional_2, built=True> </code></pre></div> </div> <h3 id="tf-autograph">TF autograph</h3> <p>In Keras 2, TF autograph is enabled by default on the <code>call()</code> method of custom layers. In Keras 3, it is not.
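</p>

<p>As an illustration (using a hypothetical layer, not one from this guide), data-dependent branching that must run inside a compiled graph can be expressed with <code>tf.cond</code>, which traces both branches explicitly and therefore works without AutoGraph:</p>

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import numpy as np
import tensorflow as tf
import keras


class CondLayer(keras.layers.Layer):
    """Hypothetical layer: randomly doubles or halves its inputs."""

    def call(self, inputs):
        # tf.cond builds both branches into the graph, so no Python-level
        # `if` on a symbolic tensor (and hence no AutoGraph) is needed.
        return tf.cond(
            tf.random.uniform(()) > 0.5,
            lambda: inputs * 2,
            lambda: inputs / 2,
        )


model = keras.models.Sequential([CondLayer()])
model.compile(optimizer="adam", loss="mse")
data = np.random.uniform(size=[3, 3])
out = model.predict(data)  # runs fine without @tf.function on call()
```

<p>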
This means you may have to use <code>tf.cond</code> ops if you're using control flow, or alternatively you can decorate your <code>call()</code> method with <code>@tf.function</code>.</p> <p>You would encounter an error as follows:</p> <div class="codehilite"><pre><span></span><code>OperatorNotAllowedInGraphError: Exception encountered when calling MyCustomLayer.call(). Using a symbolic `tf.Tensor` as a Python `bool` is not allowed. You can attempt the following resolutions to the problem: If you are running in Graph mode, use Eager execution mode or decorate this function with @tf.function. If you are using AutoGraph, you can try decorating this function with @tf.function. If that does not work, then you may be using an unsupported feature or your source code may not be visible to AutoGraph. Here is a link for more information: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/limitations.md#access-to-source-code
</code></pre></div> <p>The following snippet of code will reproduce the above error:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">MyCustomLayer</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="k">if</span> <span class="n">tf</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(())</span> <span class="o">></span> <span class="mf">0.5</span><span class="p">:</span> <span class="k">return</span> <span class="n">inputs</span> <span class="o">*</span> <span class="mi">2</span> <span class="k">else</span><span class="p">:</span> <span class="k">return</span> <span class="n">inputs</span> <span class="o">/</span> <span class="mi">2</span> <span class="n">layer</span> <span class="o">=</span> <span class="n">MyCustomLayer</span><span class="p">()</span> <span class="n">data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span><span class="n">layer</span><span class="p">])</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span 
class="n">optimizer</span><span class="o">=</span><span class="s2">"adam"</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"mse"</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> </code></pre></div> <p><strong>How to fix it:</strong> decorate your <code>call()</code> method with <code>@tf.function</code></p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">MyCustomLayer</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="nd">@tf</span><span class="o">.</span><span class="n">function</span><span class="p">()</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="k">if</span> <span class="n">tf</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(())</span> <span class="o">></span> <span class="mf">0.5</span><span class="p">:</span> <span class="k">return</span> <span class="n">inputs</span> <span class="o">*</span> <span class="mi">2</span> <span class="k">else</span><span class="p">:</span> <span class="k">return</span> <span class="n">inputs</span> <span class="o">/</span> <span class="mi">2</span> <span class="n">layer</span> <span class="o">=</span> <span class="n">MyCustomLayer</span><span class="p">()</span> <span class="n">data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span 
class="o">=</span><span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span><span class="n">layer</span><span class="p">])</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="s2">"adam"</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"mse"</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 43ms/step array([[0.59727275, 1.9986179 , 1.5514829 ], [0.56239295, 1.6529864 , 0.33085832], [0.67086476, 1.5208522 , 1.99276 ]], dtype=float32) </code></pre></div> </div> <h3 id="calling-tf-ops-with-a-kerastensor">Calling TF ops with a <code>KerasTensor</code></h3> <p>Using a TF op on a Keras tensor during functional model construction is disallowed: "A KerasTensor cannot be used as input to a TensorFlow function".</p> <p>The error you would encounter would be as follows:</p> <div class="codehilite"><pre><span></span><code>ValueError: A KerasTensor cannot be used as input to a TensorFlow function. A KerasTensor is a symbolic placeholder for a shape and dtype, used when constructing Keras Functional models or Keras Functions. You can only use it as input to a Keras layer or a Keras operation (from the namespaces `keras.layers` and `keras.operations`). 
</code></pre></div> <p>The following snippet of code will reproduce the error:</p> <div class="codehilite"><pre><span></span><code><span class="nb">input</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span> <span class="n">tf</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span> </code></pre></div> <p><strong>How to fix it:</strong> use an equivalent op from <code>keras.ops</code>.</p> <div class="codehilite"><pre><span></span><code><span class="nb">input</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">([</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="nb">input</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code><KerasTensor shape=(None, 2, 2), dtype=float32, sparse=None, name=keras_tensor_6> </code></pre></div> </div> <h3 id="multioutput-model-evaluate">Multi-output model <code>evaluate()</code></h3> <p>The <code>evaluate()</code> method of a multi-output model no longer returns individual output losses separately. 
Instead, you should use the <code>metrics</code> argument in the <code>compile()</code> method to keep track of these losses.</p> <p>When dealing with multiple named outputs, such as <code>output_a</code> and <code>output_b</code>, the legacy <a href="https://www.tensorflow.org/api_docs/python/tf/keras"><code>tf.keras</code></a> would include <code>&lt;output_a&gt;_loss</code>, <code>&lt;output_b&gt;_loss</code>, and similar entries in <code>metrics</code>. However, in Keras 3, these entries are not automatically added to <code>metrics</code>. They must be explicitly provided in the <code>metrics</code> list for each individual output.</p> <p>The following snippet of code will reproduce the above behavior:</p> <div class="codehilite"><pre><span></span><code><span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> <span class="c1"># A functional model with multiple outputs</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,))</span> <span class="n">x1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">'relu'</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">'relu'</span><span class="p">)(</span><span class="n">x1</span><span class="p">)</span> <span class="n">output_1</span> <span class="o">=</span> <span class="n">layers</span><span
class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">'softmax'</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"output_1"</span><span class="p">)(</span><span class="n">x1</span><span class="p">)</span> <span class="n">output_2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">'softmax'</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"output_2"</span><span class="p">)(</span><span class="n">x2</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="p">[</span><span class="n">output_1</span><span class="p">,</span> <span class="n">output_2</span><span class="p">])</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="s1">'adam'</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s1">'categorical_crossentropy'</span><span class="p">)</span> <span class="c1"># dummy data</span> <span class="n">x_test</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span 
class="p">[</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">])</span> <span class="n">y_test</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">[</span><span class="mi">10</span><span class="p">,</span> <span class="mi">5</span><span class="p">])</span> <span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> </code></pre></div> <div class="codehilite"><pre><span></span><code><span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> <span class="c1"># A functional model with multiple outputs</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,))</span> <span class="n">x1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x1</span><span 
class="p">)</span> <span class="n">output_1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"softmax"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"output_1"</span><span class="p">)(</span><span class="n">x1</span><span class="p">)</span> <span class="n">output_2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"softmax"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"output_2"</span><span class="p">)(</span><span class="n">x2</span><span class="p">)</span> <span class="c1"># dummy data</span> <span class="n">x_test</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">[</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">])</span> <span class="n">y_test</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">[</span><span class="mi">10</span><span class="p">,</span> <span class="mi">5</span><span class="p">])</span> <span class="n">multi_output_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span 
class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="p">[</span><span class="n">output_1</span><span class="p">,</span> <span class="n">output_2</span><span class="p">])</span> <span class="n">multi_output_model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="s2">"adam"</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"categorical_crossentropy"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"categorical_crossentropy"</span><span class="p">,</span> <span class="s2">"categorical_crossentropy"</span><span class="p">],</span> <span class="p">)</span> <span class="n">multi_output_model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 112ms/step - loss: 4.0217 - output_1_categorical_crossentropy: 4.0217 [4.021683692932129, 4.021683692932129] </code></pre></div> </div> <h3 id="tensorflow-variables-tracking">TensorFlow variables tracking</h3> <p>Setting a <a href="https://www.tensorflow.org/api_docs/python/tf/Variable"><code>tf.Variable</code></a> as an attribute of a Keras 3 layer or model will not automatically track the variable, unlike in Keras 2. 
The following snippet of code will show that the <a href="https://www.tensorflow.org/api_docs/python/tf/Variable"><code>tf.Variable</code></a> objects are not being tracked.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">MyCustomLayer</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">units</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span> <span class="o">=</span> <span class="n">units</span> <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="n">input_dim</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="n">initial_value</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">input_dim</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">]))</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span> <span class="o">=</span> <span class="n">tf</span><span
class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="n">initial_value</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">,]))</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span> <span class="n">layer</span> <span class="o">=</span> <span class="n">MyCustomLayer</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span> <span class="n">data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span><span class="n">layer</span><span class="p">])</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="s2">"adam"</span><span 
class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"mse"</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> <span class="c1"># The model does not have any trainable variables</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">trainable_variables</span><span class="p">)</span> </code></pre></div> <p>You will see the following warning:</p> <div class="codehilite"><pre><span></span><code>UserWarning: The model does not have any trainable weights. warnings.warn("The model does not have any trainable weights.") </code></pre></div> <p><strong>How to fix it:</strong> use the <code>self.add_weight()</code> method or opt for a <code>keras.Variable</code> instead.
If you are currently using <a href="https://www.tensorflow.org/api_docs/python/tf/Variable"><code>tf.Variable</code></a>, you can switch to <code>keras.Variable</code>.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">MyCustomLayer</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">units</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span> <span class="o">=</span> <span class="n">units</span> <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="n">input_dim</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="n">input_dim</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">],</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">"zeros"</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span> <span
class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">,</span> <span class="p">],</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">"zeros"</span><span class="p">,</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">b</span> <span class="n">layer</span> <span class="o">=</span> <span class="n">MyCustomLayer</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span> <span class="n">data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span><span class="n">layer</span><span class="p">])</span> <span class="n">model</span><span class="o">.</span><span 
class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="s2">"adam"</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"mse"</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> <span class="c1"># Verify that the variables are now being tracked</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">trainable_variables</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 33ms/step [<KerasVariable shape=(3, 3), dtype=float32, path=sequential_2/my_custom_layer_1/variable>, <KerasVariable shape=(3,), dtype=float32, path=sequential_2/my_custom_layer_1/variable_1>] </code></pre></div> </div> <h3 id="none-entries-in-nested-call-arguments"><code>None</code> entries in nested <code>call()</code> arguments</h3> <p><code>None</code> entries are not allowed as part of nested (e.g. list/tuples) tensor arguments in <code>Layer.call()</code>, nor as part of <code>call()</code>'s nested return values.</p> <p>If the <code>None</code> in the argument is intentional and serves a specific purpose, ensure that the argument is optional and structure it as a separate parameter. 
For example, consider defining the <code>call</code> method with an optional argument.</p> <p>The following snippet of code will reproduce the error.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">CustomLayer</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">foo</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="s2">"foo"</span><span class="p">]</span> <span class="n">baz</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="s2">"bar"</span><span class="p">][</span><span class="s2">"baz"</span><span class="p">]</span> <span class="k">if</span> <span class="n">baz</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="k">return</span> <span class="n">foo</span> <span class="o">+</span> <span class="n">baz</span> <span class="k">return</span> <span class="n">foo</span> <span class="n">layer</span> <span class="o">=</span> <span class="n">CustomLayer</span><span class="p">()</span> <span class="n">inputs</span> <span class="o">=</span> <span class="p">{</span> <span class="s2">"foo"</span><span class="p">:</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span
class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"foo"</span><span class="p">),</span> <span class="s2">"bar"</span><span class="p">:</span> <span class="p">{</span> <span class="s2">"baz"</span><span class="p">:</span> <span class="kc">None</span><span class="p">,</span> <span class="p">},</span> <span class="p">}</span> <span class="n">layer</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> </code></pre></div> <p><strong>How to fix it:</strong></p> <p><strong>Solution 1:</strong> Replace <code>None</code> with a value, like this:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">CustomLayer</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">foo</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="s2">"foo"</span><span class="p">]</span> <span class="n">baz</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="s2">"bar"</span><span class="p">][</span><span class="s2">"baz"</span><span class="p">]</span> <span class="k">return</span> <span class="n">foo</span> <span class="o">+</span> <span class="n">baz</span> <span class="n">layer</span> <span class="o">=</span> <span 
class="n">CustomLayer</span><span class="p">()</span> <span class="n">inputs</span> <span class="o">=</span> <span class="p">{</span> <span class="s2">"foo"</span><span class="p">:</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"foo"</span><span class="p">),</span> <span class="s2">"bar"</span><span class="p">:</span> <span class="p">{</span> <span class="s2">"baz"</span><span class="p">:</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"bar"</span><span class="p">),</span> <span class="p">},</span> <span class="p">}</span> <span class="n">layer</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code><KerasTensor shape=(None, 1), dtype=float32, sparse=False, name=keras_tensor_14> </code></pre></div> </div> <p><strong>Solution 2:</strong> Define the call method with an optional argument. 
Here is an example of this fix:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">CustomLayer</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">foo</span><span class="p">,</span> <span class="n">baz</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="k">if</span> <span class="n">baz</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="k">return</span> <span class="n">foo</span> <span class="o">+</span> <span class="n">baz</span> <span class="k">return</span> <span class="n">foo</span> <span class="n">layer</span> <span class="o">=</span> <span class="n">CustomLayer</span><span class="p">()</span> <span class="n">foo</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"foo"</span><span class="p">)</span> <span class="n">baz</span> <span class="o">=</span> <span class="kc">None</span> <span class="n">layer</span><span class="p">(</span><span class="n">foo</span><span class="p">,</span> <span class="n">baz</span><span class="o">=</span><span class="n">baz</span><span 
class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code><KerasTensor shape=(None, 1), dtype=float32, sparse=False, name=keras_tensor_15> </code></pre></div> </div> <h3 id="statebuilding-issues">State-building issues</h3> <p>Keras 3 is significantly stricter than Keras 2 about when state (e.g. numerical weight variables) can be created. Keras 3 wants all state to be created before the model can be trained. This is a requirement for using JAX (whereas TensorFlow was very lenient about state creation timing).</p> <p>Keras layers should create their state either in their constructor (<code>__init__()</code> method) or in their <code>build()</code> method. They should avoid creating state in <code>call()</code>.</p> <p>If you ignore this recommendation and create state in <code>call()</code> anyway (e.g. by calling a previously unbuilt layer), then Keras will attempt to build the layer automatically by calling the <code>call()</code> method on symbolic inputs before training. However, this attempt at automatic state creation may fail in certain cases. This will cause an error that looks like this:</p> <div class="codehilite"><pre><span></span><code>Layer 'frame_position_embedding' looks like it has unbuilt state, but Keras is not able to trace the layer `call()` in order to build it automatically. Possible causes: 1. The `call()` method of your layer may be crashing. Try to `__call__()` the layer eagerly on some test input first to see if it works. E.g. `x = np.random.random((3, 4)); y = layer(x)` 2. If the `call()` method is correct, then you may need to implement the `def build(self, input_shape)` method on your layer. It should create all variables used by the layer (e.g. by calling `layer.build()` on all its children layers).
</code></pre></div> <p>You could reproduce this error with the following layer, when used with the JAX backend:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">PositionalEmbedding</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sequence_length</span><span class="p">,</span> <span class="n">output_dim</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">position_embeddings</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span> <span class="n">input_dim</span><span class="o">=</span><span class="n">sequence_length</span><span class="p">,</span> <span class="n">output_dim</span><span class="o">=</span><span class="n">output_dim</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length</span> <span class="o">=</span> <span class="n">sequence_length</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_dim</span> <span class="o">=</span> <span class="n">output_dim</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">inputs</span> <span class="o">=</span> <span 
class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">compute_dtype</span><span class="p">)</span> <span class="n">length</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span> <span class="n">positions</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">start</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">stop</span><span class="o">=</span><span class="n">length</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="n">embedded_positions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">position_embeddings</span><span class="p">(</span><span class="n">positions</span><span class="p">)</span> <span class="k">return</span> <span class="n">inputs</span> <span class="o">+</span> <span class="n">embedded_positions</span> </code></pre></div> <p><strong>How to fix it:</strong> Do exactly what the error message asks. First, try to run the layer eagerly to see if the <code>call()</code> method is in fact correct (note: if it was working in Keras 2, then it is correct and does not need to be changed). If it is indeed correct, then you should implement a <code>build(self, input_shape)</code> method that creates all of the layer's state, including the state of sublayers. 
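</p> <p>For instance, following the first step suggested by the error message, you can call the layer eagerly on concrete data. Below is a minimal sketch of that check; it reproduces the <code>PositionalEmbedding</code> layer from above so it is self-contained, and the shapes are made up for illustration:</p> <div class="codehilite"><pre><code>import numpy as np
import keras
from keras import layers

# The layer from above, reproduced so this check is self-contained.
class PositionalEmbedding(keras.layers.Layer):
    def __init__(self, sequence_length, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.position_embeddings = layers.Embedding(
            input_dim=sequence_length, output_dim=output_dim
        )
        self.sequence_length = sequence_length
        self.output_dim = output_dim

    def call(self, inputs):
        inputs = keras.ops.cast(inputs, self.compute_dtype)
        length = keras.ops.shape(inputs)[1]
        positions = keras.ops.arange(start=0, stop=length, step=1)
        embedded_positions = self.position_embeddings(positions)
        return inputs + embedded_positions

# Calling the layer eagerly works: state gets created on the fly from
# concrete shapes. The failure only occurs when Keras traces `call()`
# symbolically, which is why a `build()` method is still needed.
layer = PositionalEmbedding(sequence_length=10, output_dim=8)
x = np.random.random((3, 10, 8))
y = layer(x)
</code></pre></div> <p>Since the eager call succeeds, the <code>call()</code> method is correct, and the remaining step is to add a <code>build()</code> method.</p> <p>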
Here's the fix as applied for the layer above (note the <code>build()</code> method):</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">PositionalEmbedding</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sequence_length</span><span class="p">,</span> <span class="n">output_dim</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">position_embeddings</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span> <span class="n">input_dim</span><span class="o">=</span><span class="n">sequence_length</span><span class="p">,</span> <span class="n">output_dim</span><span class="o">=</span><span class="n">output_dim</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length</span> <span class="o">=</span> <span class="n">sequence_length</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_dim</span> <span class="o">=</span> <span class="n">output_dim</span> <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span 
class="n">position_embeddings</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">compute_dtype</span><span class="p">)</span> <span class="n">length</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span> <span class="n">positions</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">start</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">stop</span><span class="o">=</span><span class="n">length</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="n">embedded_positions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">position_embeddings</span><span class="p">(</span><span class="n">positions</span><span class="p">)</span> <span class="k">return</span> <span class="n">inputs</span> <span class="o">+</span> <span class="n">embedded_positions</span> </code></pre></div> <h3 id="removed-features">Removed 
features</h3> <p>A small number of legacy features with very low usage were removed from Keras 3 as a cleanup measure:</p> <ul> <li><code>keras.layers.ThresholdedReLU</code> is removed. Instead, you can simply use the <code>ReLU</code> layer with the argument <code>threshold</code>.</li> <li>Symbolic <code>Layer.add_loss()</code>: Symbolic <code>add_loss()</code> is removed (you can still use <code>add_loss()</code> inside the <code>call()</code> method of a layer/model).</li> <li>Locally connected layers (<code>LocallyConnected1D</code>, <code>LocallyConnected2D</code>) are removed due to very low usage. To use locally connected layers, copy the layer implementation into your own codebase.</li> <li><code>keras.layers.experimental.RandomFourierFeatures</code> is removed due to very low usage. To use it, copy the layer implementation into your own codebase.</li> <li>Removed layer attributes: Layer attributes <code>metrics</code> and <code>dynamic</code> are removed. <code>metrics</code> is still available on the <code>Model</code> class.</li> <li>The <code>constants</code> and <code>time_major</code> arguments in RNN layers are removed. The <code>constants</code> argument was a remnant of Theano and had very low usage. The <code>time_major</code> argument also had very low usage.</li> <li><code>reset_metrics</code> argument: The <code>reset_metrics</code> argument is removed from <code>model.*_on_batch()</code> methods. This argument had very low usage.</li> <li>The <code>keras.constraints.RadialConstraint</code> object is removed. This object had very low usage.</li> </ul> <hr /> <h2 id="transitioning-to-backendagnostic-keras-3">Transitioning to backend-agnostic Keras 3</h2> <p>Keras 3 code with the TensorFlow backend will work with native TensorFlow APIs.
However, if you want your code to be backend-agnostic, you will need to:</p> <ul> <li>Replace all of the <a href="https://www.tensorflow.org/api_docs/python/tf/*"><code>tf.*</code></a> API calls with their equivalent Keras APIs.</li> <li>Convert your custom <code>train_step</code>/<code>test_step</code> methods to a multi-framework implementation.</li> <li>Make sure you're using stateless <code>keras.random</code> ops correctly in your layers.</li> </ul> <p>Let's go over each point in detail.</p> <h3 id="switching-to-keras-ops">Switching to Keras ops</h3> <p>In many cases, this is all you need to do to run your custom layers and metrics with JAX and PyTorch: replace any <a href="https://www.tensorflow.org/api_docs/python/tf/*"><code>tf.*</code></a>, <a href="https://www.tensorflow.org/api_docs/python/tf/math*"><code>tf.math*</code></a>, <a href="https://www.tensorflow.org/api_docs/python/tf/linalg/*"><code>tf.linalg.*</code></a>, etc. with <code>keras.ops.*</code>. Most TF ops should be consistent with Keras 3. If the names are different, they will be highlighted in this guide.</p> <h4 id="numpy-ops">NumPy ops</h4> <p>Keras implements the NumPy API as part of <code>keras.ops</code>.</p> <p>The table below only lists a small subset of TensorFlow and Keras ops; ops not listed are usually named the same in both frameworks (e.g.
<code>reshape</code>, <code>matmul</code>, <code>cast</code>, etc.)</p> <table> <thead> <tr> <th>TensorFlow</th> <th>Keras 3.0</th> </tr> </thead> <tbody> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/abs"><code>tf.abs</code></a></td> <td><a href="/api/ops/numpy#absolute-function"><code>keras.ops.absolute</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/reduce_all"><code>tf.reduce_all</code></a></td> <td><a href="/api/ops/numpy#all-function"><code>keras.ops.all</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/reduce_max"><code>tf.reduce_max</code></a></td> <td><a href="/api/ops/numpy#amax-function"><code>keras.ops.amax</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/reduce_min"><code>tf.reduce_min</code></a></td> <td><a href="/api/ops/numpy#amin-function"><code>keras.ops.amin</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/reduce_any"><code>tf.reduce_any</code></a></td> <td><a href="/api/ops/numpy#any-function"><code>keras.ops.any</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/concat"><code>tf.concat</code></a></td> <td><a href="/api/ops/numpy#concatenate-function"><code>keras.ops.concatenate</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/range"><code>tf.range</code></a></td> <td><a href="/api/ops/numpy#arange-function"><code>keras.ops.arange</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/acos"><code>tf.acos</code></a></td> <td><a href="/api/ops/numpy#arccos-function"><code>keras.ops.arccos</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/asin"><code>tf.asin</code></a></td> <td><a href="/api/ops/numpy#arcsin-function"><code>keras.ops.arcsin</code></a></td> </tr> <tr> <td><a 
href="https://www.tensorflow.org/api_docs/python/tf/asinh"><code>tf.asinh</code></a></td> <td><a href="/api/ops/numpy#arcsinh-function"><code>keras.ops.arcsinh</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/atan"><code>tf.atan</code></a></td> <td><a href="/api/ops/numpy#arctan-function"><code>keras.ops.arctan</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/atan2"><code>tf.atan2</code></a></td> <td><a href="/api/ops/numpy#arctan2-function"><code>keras.ops.arctan2</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/atanh"><code>tf.atanh</code></a></td> <td><a href="/api/ops/numpy#arctanh-function"><code>keras.ops.arctanh</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/convert_to_tensor"><code>tf.convert_to_tensor</code></a></td> <td><a href="/api/ops/core#converttotensor-function"><code>keras.ops.convert_to_tensor</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/reduce_mean"><code>tf.reduce_mean</code></a></td> <td><a href="/api/ops/numpy#mean-function"><code>keras.ops.mean</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/clip_by_value"><code>tf.clip_by_value</code></a></td> <td><a href="/api/ops/numpy#clip-function"><code>keras.ops.clip</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/math/conj"><code>tf.math.conj</code></a></td> <td><a href="/api/ops/numpy#conjugate-function"><code>keras.ops.conjugate</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/linalg/diag_part"><code>tf.linalg.diag_part</code></a></td> <td><a href="/api/ops/numpy#diagonal-function"><code>keras.ops.diagonal</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/reverse"><code>tf.reverse</code></a></td> <td><a 
href="/api/ops/numpy#flip-function"><code>keras.ops.flip</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/gather"><code>tf.gather</code></a></td> <td><a href="/api/ops/numpy#take-function"><code>keras.ops.take</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/math/is_finite"><code>tf.math.is_finite</code></a></td> <td><a href="/api/ops/numpy#isfinite-function"><code>keras.ops.isfinite</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/math/is_inf"><code>tf.math.is_inf</code></a></td> <td><a href="/api/ops/numpy#isinf-function"><code>keras.ops.isinf</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/math/is_nan"><code>tf.math.is_nan</code></a></td> <td><a href="/api/ops/numpy#isnan-function"><code>keras.ops.isnan</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/reduce_max"><code>tf.reduce_max</code></a></td> <td><a href="/api/ops/numpy#max-function"><code>keras.ops.max</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/reduce_mean"><code>tf.reduce_mean</code></a></td> <td><a href="/api/ops/numpy#mean-function"><code>keras.ops.mean</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/reduce_min"><code>tf.reduce_min</code></a></td> <td><a href="/api/ops/numpy#min-function"><code>keras.ops.min</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/rank"><code>tf.rank</code></a></td> <td><a href="/api/ops/numpy#ndim-function"><code>keras.ops.ndim</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/math/pow"><code>tf.math.pow</code></a></td> <td><a href="/api/ops/numpy#power-function"><code>keras.ops.power</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/reduce_prod"><code>tf.reduce_prod</code></a></td> <td><a 
href="/api/ops/numpy#prod-function"><code>keras.ops.prod</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/math/reduce_std"><code>tf.math.reduce_std</code></a></td> <td><a href="/api/ops/numpy#std-function"><code>keras.ops.std</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/reduce_sum"><code>tf.reduce_sum</code></a></td> <td><a href="/api/ops/numpy#sum-function"><code>keras.ops.sum</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/gather"><code>tf.gather</code></a></td> <td><a href="/api/ops/numpy#take-function"><code>keras.ops.take</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/gather_nd"><code>tf.gather_nd</code></a></td> <td><a href="/api/ops/numpy#takealongaxis-function"><code>keras.ops.take_along_axis</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/math/reduce_variance"><code>tf.math.reduce_variance</code></a></td> <td><a href="/api/ops/numpy#var-function"><code>keras.ops.var</code></a></td> </tr> </tbody> </table> <h4 id="others-ops">Other ops</h4> <table> <thead> <tr> <th>TensorFlow</th> <th>Keras 3.0</th> </tr> </thead> <tbody> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits"><code>tf.nn.sigmoid_cross_entropy_with_logits</code></a></td> <td><a href="/api/ops/nn#binarycrossentropy-function"><code>keras.ops.binary_crossentropy</code></a> (mind the <code>from_logits</code> argument)</td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/nn/sparse_softmax_cross_entropy_with_logits"><code>tf.nn.sparse_softmax_cross_entropy_with_logits</code></a></td> <td><a href="/api/ops/nn#sparsecategoricalcrossentropy-function"><code>keras.ops.sparse_categorical_crossentropy</code></a> (mind the <code>from_logits</code> argument)</td> </tr> <tr> <td><a
href="https://www.tensorflow.org/api_docs/python/tf/nn/softmax_cross_entropy_with_logits"><code>tf.nn.softmax_cross_entropy_with_logits</code></a></td> <td><code>keras.ops.categorical_crossentropy(target, output, from_logits=False, axis=-1)</code></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/nn/conv1d"><code>tf.nn.conv1d</code></a>, <a href="https://www.tensorflow.org/api_docs/python/tf/nn/conv2d"><code>tf.nn.conv2d</code></a>, <a href="https://www.tensorflow.org/api_docs/python/tf/nn/conv3d"><code>tf.nn.conv3d</code></a>, <a href="https://www.tensorflow.org/api_docs/python/tf/nn/convolution"><code>tf.nn.convolution</code></a></td> <td><a href="/api/ops/nn#conv-function"><code>keras.ops.conv</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/nn/conv_transpose"><code>tf.nn.conv_transpose</code></a>, <a href="https://www.tensorflow.org/api_docs/python/tf/nn/conv1d_transpose"><code>tf.nn.conv1d_transpose</code></a>, <a href="https://www.tensorflow.org/api_docs/python/tf/nn/conv2d_transpose"><code>tf.nn.conv2d_transpose</code></a>, <a href="https://www.tensorflow.org/api_docs/python/tf/nn/conv3d_transpose"><code>tf.nn.conv3d_transpose</code></a></td> <td><a href="/api/ops/nn#convtranspose-function"><code>keras.ops.conv_transpose</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d"><code>tf.nn.depthwise_conv2d</code></a></td> <td><a href="/api/ops/nn#depthwiseconv-function"><code>keras.ops.depthwise_conv</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/nn/separable_conv2d"><code>tf.nn.separable_conv2d</code></a></td> <td><a href="/api/ops/nn#separableconv-function"><code>keras.ops.separable_conv</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/nn/batch_normalization"><code>tf.nn.batch_normalization</code></a></td> <td>No direct equivalent; use <a
href="/api/layers/normalization_layers/batch_normalization#batchnormalization-class"><code>keras.layers.BatchNormalization</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/nn/dropout"><code>tf.nn.dropout</code></a></td> <td><a href="/api/random/random_ops#dropout-function"><code>keras.random.dropout</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/nn/embedding_lookup"><code>tf.nn.embedding_lookup</code></a></td> <td><a href="/api/ops/numpy#take-function"><code>keras.ops.take</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/nn/l2_normalize"><code>tf.nn.l2_normalize</code></a></td> <td><a href="/api/utils/python_utils#normalize-function"><code>keras.utils.normalize</code></a> (not an op)</td> </tr> <tr> <td><code>x.numpy</code></td> <td><a href="/api/ops/core#converttonumpy-function"><code>keras.ops.convert_to_numpy</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/scatter_nd_update"><code>tf.scatter_nd_update</code></a></td> <td><a href="/api/ops/core#scatterupdate-function"><code>keras.ops.scatter_update</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update"><code>tf.tensor_scatter_nd_update</code></a></td> <td><a href="/api/ops/core#sliceupdate-function"><code>keras.ops.slice_update</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/signal/fft2d"><code>tf.signal.fft2d</code></a></td> <td><a href="/api/ops/fft#fft2-function"><code>keras.ops.fft2</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/signal/inverse_stft"><code>tf.signal.inverse_stft</code></a></td> <td><a href="/api/ops/fft#istft-function"><code>keras.ops.istft</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/image/crop_to_bounding_box"><code>tf.image.crop_to_bounding_box</code></a></td> <td><a 
href="/api/ops/image#cropimages-function"><code>keras.ops.image.crop_images</code></a></td> </tr> <tr> <td><a href="https://www.tensorflow.org/api_docs/python/tf/image/pad_to_bounding_box"><code>tf.image.pad_to_bounding_box</code></a></td> <td><a href="/api/ops/image#padimages-function"><code>keras.ops.image.pad_images</code></a></td> </tr> </tbody> </table> <h3 id="custom-trainstep-methods">Custom <code>train_step()</code> methods</h3> <p>Your models may include a custom <code>train_step()</code> or <code>test_step()</code> method that relies on TensorFlow-only APIs – for instance, your <code>train_step()</code> method may leverage TensorFlow's <a href="https://www.tensorflow.org/api_docs/python/tf/GradientTape"><code>tf.GradientTape</code></a>. To convert such models to run on JAX or PyTorch, you will have to write a different <code>train_step()</code> implementation for each backend you want to support.</p> <p>In some cases, you might be able to simply override the <code>Model.compute_loss()</code> method and make it fully backend-agnostic, instead of overriding <code>train_step()</code>.
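</p> <p>For reference, the TensorFlow-only pattern being replaced typically looks like the sketch below (a hypothetical <code>MyModel</code> with a single <code>Dense</code> layer; it assumes the TensorFlow backend, since it uses <code>tf.GradientTape</code> directly):</p> <div class="codehilite"><pre><code>import tensorflow as tf
import keras

class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(1)

    def call(self, x):
        return self.dense(x)

    def train_step(self, data):
        x, y = data
        # TensorFlow-only: gradients are computed with tf.GradientTape.
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compute_loss(y=y, y_pred=y_pred)
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply(gradients, self.trainable_variables)
        # Update the built-in loss tracker and any compiled metrics.
        for metric in self.metrics:
            if metric.name == "loss":
                metric.update_state(loss)
            else:
                metric.update_state(y, y_pred)
        return {m.name: m.result() for m in self.metrics}
</code></pre></div> <p>Because it calls <code>tf.GradientTape</code>, this <code>train_step()</code> only runs on the TensorFlow backend, which is why a backend-agnostic alternative is preferable when possible.</p> <p>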
Here's an example of a model with a custom <code>compute_loss()</code> method which works across JAX, TensorFlow, and PyTorch:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">MyModel</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="nf">compute_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">y_pred</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">sample_weight</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">mean_squared_error</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">,</span> <span class="n">sample_weight</span><span class="p">))</span> <span class="k">return</span> <span class="n">loss</span> </code></pre></div> <p>If you need to modify the optimization mechanism itself, beyond the loss computation, then you will need to override <code>train_step()</code>, and implement one <code>train_step</code> method per backend, like below.</p> <p>See the following guides for details on how each backend should be handled:</p> <ul> <li><a href="https://keras.io/guides/custom_train_step_in_jax/">Customizing what happens in <code>fit()</code> with JAX</a></li> <li><a
href="https://keras.io/guides/custom_train_step_in_tensorflow/">Customizing what happens in <code>fit()</code> with TensorFlow</a></li> <li><a href="https://keras.io/guides/custom_train_step_in_torch/">Customizing what happens in <code>fit()</code> with PyTorch</a></li> </ul> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">MyModel</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="k">if</span> <span class="n">keras</span><span class="o">.</span><span class="n">backend</span><span class="o">.</span><span class="n">backend</span><span class="p">()</span> <span class="o">==</span> <span class="s2">"jax"</span><span class="p">:</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_jax_train_step</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="k">elif</span> <span class="n">keras</span><span class="o">.</span><span class="n">backend</span><span class="o">.</span><span class="n">backend</span><span class="p">()</span> <span class="o">==</span> <span class="s2">"tensorflow"</span><span class="p">:</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_tensorflow_train_step</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="k">elif</span> <span 
class="n">keras</span><span class="o">.</span><span class="n">backend</span><span class="o">.</span><span class="n">backend</span><span class="p">()</span> <span class="o">==</span> <span class="s2">"torch"</span><span class="p">:</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">_torch_train_step</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="k">def</span> <span class="nf">_jax_train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">state</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span> <span class="k">pass</span> <span class="c1"># See guide: keras.io/guides/custom_train_step_in_jax/</span> <span class="k">def</span> <span class="nf">_tensorflow_train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span> <span class="k">pass</span> <span class="c1"># See guide: keras.io/guides/custom_train_step_in_tensorflow/</span> <span class="k">def</span> <span class="nf">_torch_train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span> <span class="k">pass</span> <span class="c1"># See guide: keras.io/guides/custom_train_step_in_torch/</span> </code></pre></div> <h3 id="rngusing-layers">RNG-using layers</h3> <p>Keras 3 has a new <code>keras.random</code> namespace, containing:</p> <ul> <li><a href="/api/random/random_ops#normal-function"><code>keras.random.normal</code></a></li> <li><a href="/api/random/random_ops#uniform-function"><code>keras.random.uniform</code></a></li> <li><a href="/api/random/random_ops#shuffle-function"><code>keras.random.shuffle</code></a></li> <li>etc.</li> </ul> <p>These operations are 
<strong>stateless</strong>, which means that if you pass a <code>seed</code> argument, they will return the same result every time. Like this:</p> <div class="codehilite"><pre><span></span><code><span class="nb">print</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(),</span> <span class="n">seed</span><span class="o">=</span><span class="mi">123</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(),</span> <span class="n">seed</span><span class="o">=</span><span class="mi">123</span><span class="p">))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>tf.Tensor(0.7832616, shape=(), dtype=float32) tf.Tensor(0.7832616, shape=(), dtype=float32) </code></pre></div> </div> <p>Crucially, this differs from the behavior of stateful <a href="https://www.tensorflow.org/api_docs/python/tf/random"><code>tf.random</code></a> ops:</p> <div class="codehilite"><pre><span></span><code><span class="nb">print</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(),</span> <span class="n">seed</span><span class="o">=</span><span class="mi">123</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span 
class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(),</span> <span class="n">seed</span><span class="o">=</span><span class="mi">123</span><span class="p">))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>tf.Tensor(2.4435377, shape=(), dtype=float32) tf.Tensor(-0.6386405, shape=(), dtype=float32) </code></pre></div> </div> <p>When you write an RNG-using layer, such as a custom dropout layer, you will want to use a different seed value at each layer call. However, you cannot just increment a Python integer and pass it: while this would work fine in eager execution, it would not work as expected under compilation (which is available with JAX, TensorFlow, and PyTorch), because the first Python integer seed value seen by the layer would be hardcoded into the compiled graph.</p> <p>To address this, pass a stateful <a href="/api/random/seed_generator#seedgenerator-class"><code>keras.random.SeedGenerator</code></a> instance as the <code>seed</code> argument, like this:</p> <div class="codehilite"><pre><span></span><code><span class="n">seed_generator</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">SeedGenerator</span><span class="p">(</span><span class="mi">1337</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(),</span> <span class="n">seed</span><span class="o">=</span><span class="n">seed_generator</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span
class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(),</span> <span class="n">seed</span><span class="o">=</span><span class="n">seed_generator</span><span class="p">))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>tf.Tensor(0.6077996, shape=(), dtype=float32) tf.Tensor(0.8211102, shape=(), dtype=float32) </code></pre></div> </div> <p>So when writing an RNG-using layer, you would use the following pattern:</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">RandomNoiseLayer</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">noise_rate</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">noise_rate</span> <span class="o">=</span> <span class="n">noise_rate</span> <span class="bp">self</span><span class="o">.</span><span class="n">seed_generator</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">SeedGenerator</span><span class="p">(</span><span class="mi">1337</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span 
class="n">inputs</span><span class="p">):</span> <span class="n">noise</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">),</span> <span class="n">minval</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">maxval</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">noise_rate</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">seed_generator</span> <span class="p">)</span> <span class="k">return</span> <span class="n">inputs</span> <span class="o">+</span> <span class="n">noise</span> </code></pre></div> <p>Such a layer is safe to use in any setting – in eager execution or in a compiled model. Each layer call will use a different seed value, as expected.</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#migrating-keras-2-code-to-multibackend-keras-3'>Migrating Keras 2 code to multi-backend Keras 3</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#going-from-keras-2-to-keras-3-with-the-tensorflow-backend'>Going from Keras 2 to Keras 3 with the TensorFlow backend</a> </div> <div class='k-outline-depth-3'> <a href='#jitcompile-is-set-to-true-by-default-on-gpu'><code>jit_compile</code> is set to <code>True</code> by default on GPU.</a> </div> <div class='k-outline-depth-3'> <a href='#saving-a-model-in-the-tf-savedmodel-format'>Saving a model in the TF SavedModel format</a> </div> <div class='k-outline-depth-3'> <a href='#loading-a-tf-savedmodel'>Loading a TF SavedModel</a> </div> <div class='k-outline-depth-3'> <a href='#using-deeply-nested-inputs-in-functional-models'>Using deeply nested inputs in Functional Models</a> </div> <div
class='k-outline-depth-3'> <a href='#tf-autograph'>TF autograph</a> </div> <div class='k-outline-depth-3'> <a href='#calling-tf-ops-with-a-kerastensor'>Calling TF ops with a <code>KerasTensor</code></a> </div> <div class='k-outline-depth-3'> <a href='#multioutput-model-evaluate'>Multi-output model <code>evaluate()</code></a> </div> <div class='k-outline-depth-3'> <a href='#tensorflow-variables-tracking'>TensorFlow variables tracking</a> </div> <div class='k-outline-depth-3'> <a href='#none-entries-in-nested-call-arguments'><code>None</code> entries in nested <code>call()</code> arguments</a> </div> <div class='k-outline-depth-3'> <a href='#statebuilding-issues'>State-building issues</a> </div> <div class='k-outline-depth-3'> <a href='#removed-features'>Removed features</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#transitioning-to-backendagnostic-keras-3'>Transitioning to backend-agnostic Keras 3</a> </div> <div class='k-outline-depth-3'> <a href='#switching-to-keras-ops'>Switching to Keras ops</a> </div> <div class='k-outline-depth-3'> <a href='#custom-trainstep-methods'>Custom <code>train_step()</code> methods</a> </div> <div class='k-outline-depth-3'> <a href='#rngusing-layers'>RNG-using layers</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>
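The three-way dispatch in the custom <code>train_step()</code> shown earlier is plain Python routing on the backend name. Stripped of Keras, the pattern looks like the sketch below; the <code>Trainer</code> class and its placeholder handlers are hypothetical, and the backend name is read from the <code>KERAS_BACKEND</code> environment variable only for illustration (in real code you would call <code>keras.backend.backend()</code> as the guide does):

```python
import os


class Trainer:
    """Toy stand-in for a keras.Model subclass: one public
    train_step() routing to one private per-backend method."""

    def train_step(self, data):
        # Illustrative only: read the backend name from the
        # KERAS_BACKEND environment variable.
        backend = os.environ.get("KERAS_BACKEND", "tensorflow")
        handlers = {
            "jax": self._jax_train_step,
            "tensorflow": self._tensorflow_train_step,
            "torch": self._torch_train_step,
        }
        return handlers[backend](data)

    # Placeholder bodies; the real ones are backend-specific,
    # as described in the three guides linked above.
    def _jax_train_step(self, data):
        return ("jax", data)

    def _tensorflow_train_step(self, data):
        return ("tensorflow", data)

    def _torch_train_step(self, data):
        return ("torch", data)
```

The point of the pattern is that only one branch ever runs in a given process, so the unused backends never need to be importable.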
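The stateless-versus-stateful contrast described above does not depend on Keras or TensorFlow; it can be reproduced with the Python standard library alone. In this sketch, seeding a fresh RNG on every call mimics the stateless behavior of <code>keras.random</code> ops, while drawing repeatedly from one RNG mimics the stateful behavior of <code>tf.random</code> (the <code>stateless_normal</code> helper is hypothetical, not a Keras API):

```python
import random


def stateless_normal(seed):
    # A fresh RNG seeded identically always yields the same first
    # draw -- analogous to keras.random.normal(seed=123).
    return random.Random(seed).gauss(0.0, 1.0)


# Same seed, same result, every time (stateless behavior).
assert stateless_normal(123) == stateless_normal(123)

# A single RNG advances its internal state between draws, so
# successive calls differ -- analogous to tf.random.normal.
rng = random.Random(123)
first, second = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
assert first != second
```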
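Why does passing a <code>SeedGenerator</code> solve the compilation problem? Because the object carries tracked state that derives a fresh sub-seed on every call, instead of a Python constant that would be baked into the traced graph. A toy stdlib-only analogue of that mechanism (the <code>ToySeedGenerator</code> class and <code>noisy</code> function below are hypothetical illustrations, not Keras APIs):

```python
import random


class ToySeedGenerator:
    """Toy analogue of keras.random.SeedGenerator: a counter that
    derives a distinct sub-seed on every call, so repeated layer
    calls draw different random values even though the object was
    built from one fixed seed."""

    def __init__(self, seed):
        self._seed = seed
        self._counter = 0

    def next(self):
        self._counter += 1
        # Derive a distinct integer sub-seed from (seed, counter).
        return self._seed * 100003 + self._counter


def noisy(x, seed_gen):
    # Each call consumes a new sub-seed, so the added noise
    # differs from call to call, mirroring RandomNoiseLayer above.
    sub_seed = seed_gen.next()
    return x + random.Random(sub_seed).uniform(0.0, 0.1)


gen = ToySeedGenerator(1337)
a = noisy(1.0, gen)
b = noisy(1.0, gen)
assert a != b  # a different sub-seed on each call
```

Note that two generators built from the same seed replay the same sub-seed sequence, which is what makes runs reproducible despite each call drawing different values.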