<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no"> <meta name="description" content="Keras Core documentation"> <meta name="author" content="Keras Team"> <title>Keras: Deep Learning for humans</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css?family=Open+Sans:wght@300;400;500;600&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/landing.css" rel="stylesheet"> <link href="/css/announcement.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <!-- Masthead --> <header class="masthead text-center"> <div class="container"> <img src='/img/logo.png' class='logo' /> <div class="row"> <div class="col-xl-8 mx-auto"> <h1 
class="mb-5">Introducing Keras 3.0</h1> <div class="row mx-auto"> <div class="col-md px-1"> <a href='/getting_started/' class="btn btn-block btn-lg btn-primary">Get started</a> </div> <div class="col-md px-1"> <a href='/api/' class="btn btn-block btn-lg btn-secondary">API docs</a> </div> <div class="col-md px-1"> <a href='/guides/' class="btn btn-block btn-lg btn-secondary">Guides</a> </div> <div class="col-md px-1"> <a href='https://github.com/keras-team/keras/' class="btn btn-block btn-lg btn-secondary">GitHub</a> </div> </div> </div> </div> </header> <div class="container"> <div class="row"> <div class="col-lg"> <div class="blog-content"> <p>After five months of extensive public beta testing, we're excited to announce the official release of Keras 3.0. Keras 3 is a full rewrite of Keras that enables you to run your Keras workflows on top of JAX, TensorFlow, or PyTorch, and that unlocks brand-new large-scale model training and deployment capabilities. You can pick the framework that suits you best and switch from one to another based on your current goals. You can also use Keras as a low-level cross-framework language to develop custom components such as layers, models, or metrics that can be used in native workflows in JAX, TensorFlow, or PyTorch, all with one codebase.</p> <hr /> <h2>Welcome to multi-framework machine learning.</h2> <p>You're already familiar with the benefits of using Keras: it enables high-velocity development via an obsessive focus on great UX, API design, and debuggability. It's also a battle-tested framework that has been chosen by over 2.5M developers and that powers some of the most sophisticated, largest-scale ML systems in the world, such as the Waymo self-driving fleet and the YouTube recommendation engine. 
But what are the additional benefits of using the new multi-backend Keras 3?</p> <ul> <li><strong>Always get the best performance for your models.</strong> In our benchmarks, we found that JAX typically delivers the best training and inference performance on GPU, TPU, and CPU, but results vary from model to model (for instance, non-XLA TensorFlow is occasionally faster on GPU). The ability to dynamically select the backend that will deliver the best performance for your model <em>without having to change anything in your code</em> means you can always train and serve with the most efficient backend for that model.</li> <li><strong>Unlock ecosystem optionality for your models.</strong> Any Keras 3 model can be instantiated as a PyTorch <code>Module</code>, can be exported as a TensorFlow <code>SavedModel</code>, or can be instantiated as a stateless JAX function. That means that you can use your Keras 3 models with PyTorch ecosystem packages, with the full range of TensorFlow deployment & production tools (like TF-Serving, TF.js, and TFLite), and with JAX large-scale TPU training infrastructure. Write one <code>model.py</code> using Keras 3 APIs, and get access to everything the ML world has to offer.</li> <li><strong>Leverage large-scale model parallelism & data parallelism with JAX.</strong> Keras 3 includes a brand-new distribution API, the <code>keras.distribution</code> namespace, currently implemented for the JAX backend (coming soon to the TensorFlow and PyTorch backends). It makes it easy to do model parallelism, data parallelism, and combinations of both, at arbitrary model and cluster scales. Because it keeps the model definition, training logic, and sharding configuration all separate from each other, it makes your distribution workflow easy to develop and easy to maintain. See our <a href="/guides/distribution/">starter guide</a>.</li> <li><strong>Maximize reach for your open-source model releases.</strong> Want to release a pretrained model? 
Want as many people as possible to be able to use it? If you implement it in pure TensorFlow or PyTorch, it will be usable by roughly half of the community. If you implement it in Keras 3, it is instantly usable by anyone, regardless of their framework of choice (even if they're not Keras users themselves). Twice the impact at no added development cost.</li> <li><strong>Use data pipelines from any source.</strong> The Keras 3 <code>fit()</code>/<code>evaluate()</code>/<code>predict()</code> routines are compatible with <code>tf.data.Dataset</code> objects, PyTorch <code>DataLoader</code> objects, NumPy arrays, and Pandas dataframes, regardless of the backend you're using. You can train a Keras 3 + TensorFlow model on a PyTorch <code>DataLoader</code> or train a Keras 3 + PyTorch model on a <code>tf.data.Dataset</code>.</li> </ul> <hr /> <h2>The full Keras API, available for JAX, TensorFlow, and PyTorch.</h2> <p>Keras 3 implements the full Keras API and makes it available with TensorFlow, JAX, and PyTorch: over a hundred layers, dozens of metrics, loss functions, optimizers, and callbacks, the Keras training and evaluation loops, and the Keras saving & serialization infrastructure. All the APIs you know and love are here.</p> <p>Any Keras model that only uses built-in layers will immediately work with all supported backends. In fact, your existing <code>tf.keras</code> models that only use built-in layers can start running in JAX and PyTorch <em>right away</em>! That's right, your codebase just gained a whole new set of capabilities.</p> <p><img class="irasto" src="https://s3.amazonaws.com/keras.io/img/keras_3/cross_framework_keras_3.jpg" /></p> <hr /> <h2>Author multi-framework layers, models, metrics...</h2> <p>Keras 3 enables you to create components (like arbitrary custom layers or pretrained models) that work the same in any framework. In particular, Keras 3 gives you access to the <code>keras.ops</code> namespace, which works across all backends. 
It contains:</p> <ul> <li><strong>A full implementation of the NumPy API.</strong> Not something "NumPy-like": literally the NumPy API, with the same functions and the same arguments. You get <code>ops.matmul</code>, <code>ops.sum</code>, <code>ops.stack</code>, <code>ops.einsum</code>, etc.</li> <li><strong>A set of neural-network-specific functions</strong> that are absent from NumPy, such as <code>ops.softmax</code>, <code>ops.binary_crossentropy</code>, <code>ops.conv</code>, etc.</li> </ul> <p>As long as you only use ops from <code>keras.ops</code>, your custom layers, custom losses, custom metrics, and custom optimizers <strong>will work with JAX, PyTorch, and TensorFlow, with the same code</strong>. That means that you can maintain only one component implementation (e.g. a single <code>model.py</code> together with a single checkpoint file), and you can use it in all frameworks, with the same numerics.</p> <p><img class="irasto" src="https://s3.amazonaws.com/keras.io/img/keras_3/custom_component_authoring_keras_3.jpg" /></p> <hr /> <h2>...that work seamlessly with any JAX, TensorFlow, or PyTorch workflow.</h2> <p>Keras 3 is not just intended for Keras-centric workflows where you define a Keras model, a Keras optimizer, a Keras loss and metrics, and call <code>fit()</code>, <code>evaluate()</code>, and <code>predict()</code>. It's also meant to work seamlessly with low-level backend-native workflows: you can take a Keras model (or any other component, such as a loss or metric) and start using it in a JAX, TensorFlow, or PyTorch training loop, or as part of a JAX or PyTorch model, with zero friction. 
Keras 3 provides exactly the same degree of low-level implementation flexibility in JAX and PyTorch as <code>tf.keras</code> previously did in TensorFlow.</p> <p>You can:</p> <ul> <li>Write a low-level JAX training loop to train a Keras model using an <code>optax</code> optimizer, <code>jax.grad</code>, <code>jax.jit</code>, and <code>jax.pmap</code>.</li> <li>Write a low-level TensorFlow training loop to train a Keras model using <code>tf.GradientTape</code> and <code>tf.distribute</code>.</li> <li>Write a low-level PyTorch training loop to train a Keras model using a <code>torch.optim</code> optimizer, a <code>torch</code> loss function, and the <code>torch.nn.parallel.DistributedDataParallel</code> wrapper.</li> <li>Use Keras layers in a PyTorch <code>Module</code> (with the PyTorch backend, they are <code>Module</code> instances too!)</li> <li>Use any PyTorch <code>Module</code> in a Keras model as if it were a Keras layer.</li> <li>etc.</li> </ul> <p><img class="irasto" src="https://s3.amazonaws.com/keras.io/img/keras-core/custom_training_loops.jpg" /></p> <hr /> <h2>A new distribution API for large-scale data parallelism and model parallelism.</h2> <p>The models we've been working with have been getting larger and larger, so we wanted to provide a Kerasic solution to the multi-device model sharding problem. The API we designed keeps the model definition, the training logic, and the sharding configuration entirely separate from each other, meaning that your models can be written as if they were going to run on a single device. 
You can then add arbitrary sharding configurations to arbitrary models when it's time to train them.</p> <p>Data parallelism (replicating a small model identically on multiple devices) can be handled in just two lines:</p> <p><img class="irasto" src="https://s3.amazonaws.com/keras.io/img/keras_3/keras_3_data_parallel.jpg" /></p> <p>Model parallelism lets you specify sharding layouts for model variables and intermediate output tensors along multiple named dimensions. In the typical case, you would organize the available devices as a 2D grid (called a <em>device mesh</em>), where the first dimension is used for data parallelism and the second dimension is used for model parallelism. You would then configure your model to be sharded along the model dimension and replicated along the data dimension.</p> <p>The API lets you configure the layout of every variable and every output tensor via regular expressions. This makes it easy to specify the same layout for entire categories of variables at once.</p> <p><img class="irasto" src="https://s3.amazonaws.com/keras.io/img/keras_3/keras_3_model_parallel.jpg" /></p> <p>The new distribution API is intended to be multi-backend, but is currently only available for the JAX backend. TensorFlow and PyTorch support is coming soon. Get started with <a href="/guides/distribution/">this guide</a>!</p> <hr /> <h2>Pretrained models.</h2> <p>There's a wide range of pretrained models that you can start using today with Keras 3.</p> <p>All 40 Keras Applications models (the <code>keras.applications</code> namespace) are available in all backends. The large collection of pretrained models in <a href="https://keras.io/api/keras_cv/">KerasCV</a> and <a href="https://keras.io/api/keras_hub/">KerasHub</a> also works with all backends. 
This includes:</p> <ul> <li>BERT</li> <li>OPT</li> <li>Whisper</li> <li>T5</li> <li>StableDiffusion</li> <li>YOLOv8</li> <li>SegmentAnything</li> <li>etc.</li> </ul> <hr /> <h2>Support for cross-framework data pipelines with all backends.</h2> <p>Multi-framework ML also means multi-framework data loading and preprocessing. Keras 3 models can be trained using a wide range of data pipelines, regardless of whether you're using the JAX, PyTorch, or TensorFlow backend. It just works.</p> <ul> <li><code>tf.data.Dataset</code> pipelines: the reference for scalable production ML.</li> <li><code>torch.utils.data.DataLoader</code> objects.</li> <li>NumPy arrays and Pandas dataframes.</li> <li>Keras's own <code>keras.utils.PyDataset</code> objects.</li> </ul> <hr /> <h2>Progressive disclosure of complexity.</h2> <p><em>Progressive disclosure of complexity</em> is the design principle at the heart of the Keras API. Keras doesn't force you to follow a single "true" way of building and training models. Instead, it enables a wide range of different workflows, from the very high-level to the very low-level, corresponding to different user profiles.</p> <p>That means that you can start out with simple workflows (such as using <code>Sequential</code> and <code>Functional</code> models and training them with <code>fit()</code>), and when you need more flexibility, you can easily customize different components while reusing most of your prior code. As your needs become more specific, you don't suddenly fall off a complexity cliff, and you don't need to switch to a different set of tools.</p> <p>We've brought this principle to all of our backends. 
For instance, you can customize what happens in your training loop while still leveraging the power of <code>fit()</code>, without having to write your own training loop from scratch, simply by overriding the <code>train_step</code> method.</p> <p>Here's how it works in PyTorch and TensorFlow:</p> <p><img class="irasto" src="https://s3.amazonaws.com/keras.io/img/keras-core/customizing_fit.jpg" /></p> <p>And <a href="http://keras.io/guides/custom_train_step_in_jax/">here's the link</a> to the JAX version.</p> <hr /> <h2>A new stateless API for layers, models, metrics, and optimizers.</h2> <p>Do you enjoy <a href="https://en.wikipedia.org/wiki/Functional_programming">functional programming</a>? You're in for a treat.</p> <p>All stateful objects in Keras (i.e. objects that own numerical variables that get updated during training or evaluation) now have a stateless API, making it possible to use them in JAX functions (which are required to be fully stateless):</p> <ul> <li>All layers and models have a <code>stateless_call()</code> method that mirrors <code>__call__()</code>.</li> <li>All optimizers have a <code>stateless_apply()</code> method that mirrors <code>apply()</code>.</li> <li>All metrics have a <code>stateless_update_state()</code> method that mirrors <code>update_state()</code> and a <code>stateless_result()</code> method that mirrors <code>result()</code>.</li> </ul> <p>These methods have no side effects whatsoever: they take as input the current values of the target object's state variables and return the updated values as part of their outputs, e.g.:</p> <div class="codehilite"><pre><span></span><code><span class="n">outputs</span><span class="p">,</span> <span class="n">updated_non_trainable_variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">stateless_call</span><span class="p">(</span> <span class="n">trainable_variables</span><span class="p">,</span> <span 
class="n">non_trainable_variables</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <p>You never have to implement these methods yourself: they're automatically available as long as you've implemented the stateful version (e.g. <code>call()</code> or <code>update_state()</code>).</p> <hr /> <h2>Moving from Keras 2 to Keras 3</h2> <p>Keras 3 is highly backwards compatible with Keras 2: it implements the full public API surface of Keras 2, with a limited number of exceptions, listed <a href="https://github.com/keras-team/keras/issues/18467">here</a>. Most users will not have to make any code changes to start running their Keras scripts on Keras 3.</p> <p>Larger codebases are likely to require some code changes, since they are more likely to run into one of the exceptions listed above and more likely to have been using private or deprecated APIs (the <code>tf.compat.v1.keras</code> namespace, the <code>experimental</code> namespace, the <code>keras.src</code> private namespace). To help you move to Keras 3, we are releasing a complete <a href="/guides/migrating_to_keras_3/">migration guide</a> with quick fixes for all issues you might encounter.</p> <p>You also have the option to ignore the changes in Keras 3 and keep using Keras 2 with TensorFlow. This can be a good option for projects that are not actively developed but need to keep running with updated dependencies. You have two possibilities:</p> <ol> <li>If you were accessing <code>keras</code> as a standalone package, switch to the Python package <code>tf_keras</code> instead, which you can install via <code>pip install tf_keras</code>. The code and API are wholly unchanged: it's Keras 2.15 with a different package name. We will keep fixing bugs in <code>tf_keras</code> and will keep releasing new versions regularly. 
However, no new features or performance improvements will be added, since the package is now in maintenance mode.</li> <li>If you were accessing <code>keras</code> via <code>tf.keras</code>, there are no immediate changes until TensorFlow 2.16. TensorFlow 2.16+ will use Keras 3 by default. In TensorFlow 2.16+, to keep using Keras 2, first install <code>tf_keras</code>, and then export the environment variable <code>TF_USE_LEGACY_KERAS=1</code>. This will direct TensorFlow 2.16+ to resolve <code>tf.keras</code> to the locally installed <code>tf_keras</code> package. Note, however, that this affects more than your own code: it applies to every package that imports <code>tf.keras</code> in your Python process. To limit the change to your own code, import the <code>tf_keras</code> package directly instead. </li> </ol> <hr /> <h2>Enjoy the library!</h2> <p>We're excited for you to try out the new Keras and improve your workflows by leveraging multi-framework ML. Let us know how it goes: issues, points of friction, feature requests, or success stories. We're eager to hear from you!</p> <hr /> <h2>FAQ</h2> <h3>Q: Is Keras 3 compatible with legacy Keras 2?</h3> <p>Code developed with <code>tf.keras</code> can generally be run as-is with Keras 3 (with the TensorFlow backend). There's a limited number of incompatibilities you should be mindful of, all addressed in <a href="/guides/migrating_to_keras_3/">this migration guide</a>.</p> <p>When it comes to using APIs from <code>tf.keras</code> and Keras 3 side by side, that is <strong>not</strong> possible: they're different packages, running on entirely separate engines.</p> <h3>Q: Do pretrained models developed in legacy Keras 2 work with Keras 3?</h3> <p>Generally, yes. Any <code>tf.keras</code> model should work out of the box with Keras 3 on the TensorFlow backend (make sure to save it in the <code>.keras</code> v3 format). 
In addition, if the model only uses built-in Keras layers, it will also work out of the box with Keras 3 on the JAX and PyTorch backends.</p> <p>If the model contains custom layers written using TensorFlow APIs, it is usually easy to convert the code to be backend-agnostic. For instance, it only took us a few hours to convert all 40 legacy <code>tf.keras</code> models from Keras Applications to be backend-agnostic.</p> <h3>Q: Can I save a Keras 3 model in one backend and reload it in another backend?</h3> <p>Yes, you can. There is no backend specialization in saved <code>.keras</code> files whatsoever. Your saved Keras models are framework-agnostic and can be reloaded with any backend.</p> <p>However, note that reloading a model that contains custom components under a different backend requires those components to be implemented using backend-agnostic APIs, e.g. <code>keras.ops</code>.</p> <h3>Q: Can I use Keras 3 components inside <code>tf.data</code> pipelines?</h3> <p>With the TensorFlow backend, Keras 3 is fully compatible with <code>tf.data</code> (e.g. you can <code>.map()</code> a <code>Sequential</code> model into a <code>tf.data</code> pipeline).</p> <p>With a different backend, Keras 3 has limited support for <code>tf.data</code>. You won't be able to <code>.map()</code> arbitrary layers or models into a <code>tf.data</code> pipeline. However, you will be able to use specific Keras 3 preprocessing layers with <code>tf.data</code>, such as <code>IntegerLookup</code> or <code>CategoryEncoding</code>.</p> <p>Using a <code>tf.data</code> pipeline that does not itself use Keras to feed your call to <code>.fit()</code>, <code>.evaluate()</code>, or <code>.predict()</code> works out of the box with all backends.</p> <h3>Q: Do Keras 3 models behave the same when run with different backends?</h3> <p>Yes, numerics are identical across backends, within floating-point tolerance. 
However, keep in mind the following caveats:</p> <ul> <li>RNG behavior differs across backends, even after seeding: your results will be deterministic within each backend but will differ from one backend to another. So random weight initialization values and dropout masks will differ across backends.</li> <li>Due to the nature of floating-point implementations, results are only identical up to <code>1e-7</code> precision in float32, per function execution. So when you train a model for a long time, these small differences accumulate and can end up producing noticeable numerical differences.</li> <li>Due to the lack of support for average pooling with asymmetric padding in PyTorch, average pooling layers with <code>padding="same"</code> may produce different numerics on border rows/columns. This doesn't happen very often in practice: out of 40 Keras Applications vision models, only one was affected.</li> </ul> <h3>Q: Does Keras 3 support distributed training?</h3> <p>Data-parallel distribution is supported out of the box in JAX, TensorFlow, and PyTorch. Model-parallel distribution is supported out of the box for JAX with the <code>keras.distribution</code> API.</p> <p><strong>With TensorFlow:</strong></p> <p>Keras 3 is compatible with <code>tf.distribute</code>: just open a Distribution Strategy scope and create and train your model within it. <a href="http://keras.io/guides/distributed_training_with_tensorflow/">Here's an example</a>.</p> <p><strong>With PyTorch:</strong></p> <p>Keras 3 is compatible with PyTorch's <code>DistributedDataParallel</code> utility. <a href="http://keras.io/guides/distributed_training_with_torch/">Here's an example</a>.</p> <p><strong>With JAX:</strong></p> <p>You can do both data-parallel and model-parallel distribution in JAX using the <code>keras.distribution</code> API. 
For instance, to do data-parallel distribution, you only need the following code snippet:</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Assumes the JAX backend, i.e. KERAS_BACKEND="jax" set before import.</span>
<span class="kn">import</span> <span class="nn">keras</span>
<span class="n">distribution</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">distribution</span><span class="o">.</span><span class="n">DataParallel</span><span class="p">(</span><span class="n">devices</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">distribution</span><span class="o">.</span><span class="n">list_devices</span><span class="p">())</span>
<span class="n">keras</span><span class="o">.</span><span class="n">distribution</span><span class="o">.</span><span class="n">set_distribution</span><span class="p">(</span><span class="n">distribution</span><span class="p">)</span>
</code></pre></div> <p>For model-parallel distribution, see <a href="/guides/distribution/">the following guide</a>.</p> <p>You can also distribute training yourself via JAX APIs such as <code>jax.sharding</code>. <a href="http://keras.io/guides/distributed_training_with_jax/">Here's an example</a>.</p> <h3>Q: Can my custom Keras layers be used in native PyTorch <code>Modules</code> or with Flax <code>Modules</code>?</h3> <p>If they are written using only Keras APIs (e.g. the <code>keras.ops</code> namespace), then yes, your Keras layers will work out of the box with native PyTorch and JAX code. With the PyTorch backend, just use your Keras layer like any other PyTorch <code>Module</code>. In JAX, make sure to use the stateless layer API, i.e. <code>layer.stateless_call()</code>.</p> <h3>Q: Will you add more backends in the future? What about framework XYZ?</h3> <p>We're open to adding new backends as long as the target framework has a large user base or otherwise brings some unique technical benefits to the table. 
However, adding and maintaining a new backend is a large burden, so we're going to carefully consider each new backend candidate on a case-by-case basis, and we're not likely to add many new backends. We will not add any new frameworks that aren't yet well established. We are currently considering a potential backend written in <a href="https://www.modular.com/mojo">Mojo</a>. If that's something you might find useful, please let the Mojo team know.</p> </div> </div> </div> </div> </body> </html>