<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/guides/distribution/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Distributed training with Keras 3"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Distributed training with Keras 3"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Distributed training with Keras 3</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google 
Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link active" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-sublink" href="/guides/functional_api/">The Functional API</a> <a class="nav-sublink" href="/guides/sequential_model/">The Sequential model</a> <a class="nav-sublink" href="/guides/making_new_layers_and_models_via_subclassing/">Making new layers & models via subclassing</a> <a class="nav-sublink" href="/guides/training_with_built_in_methods/">Training & evaluation with the built-in methods</a> <a class="nav-sublink" href="/guides/custom_train_step_in_jax/">Customizing `fit()` with JAX</a> <a class="nav-sublink" href="/guides/custom_train_step_in_tensorflow/">Customizing `fit()` with TensorFlow</a> <a class="nav-sublink" href="/guides/custom_train_step_in_torch/">Customizing `fit()` with PyTorch</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_jax/">Writing a custom training loop in JAX</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_tensorflow/">Writing a custom training loop in TensorFlow</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_torch/">Writing a custom training loop in PyTorch</a> <a class="nav-sublink" href="/guides/serialization_and_saving/">Serialization & 
saving</a> <a class="nav-sublink" href="/guides/customizing_saving_and_serialization/">Customizing saving & serialization</a> <a class="nav-sublink" href="/guides/writing_your_own_callbacks/">Writing your own callbacks</a> <a class="nav-sublink" href="/guides/transfer_learning/">Transfer learning & fine-tuning</a> <a class="nav-sublink" href="/guides/distributed_training_with_jax/">Distributed training with JAX</a> <a class="nav-sublink" href="/guides/distributed_training_with_tensorflow/">Distributed training with TensorFlow</a> <a class="nav-sublink" href="/guides/distributed_training_with_torch/">Distributed training with PyTorch</a> <a class="nav-sublink active" href="/guides/distribution/">Distributed training with Keras 3</a> <a class="nav-sublink" href="/guides/migrating_to_keras_3/">Migrating Keras 2 code to Keras 3</a> <a class="nav-link" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var 
menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + encodeURIComponent(query); return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div
class='k-location-slug'> </div> <div class='k-content'> <h1 id="distributed-training-with-keras-3">Distributed training with Keras 3</h1> <p><strong>Author:</strong> <a href="https://github.com/qlzh727">Qianli Zhu</a><br> <strong>Date created:</strong> 2023/11/07<br> <strong>Last modified:</strong> 2023/11/07<br> <strong>Description:</strong> Complete guide to the distribution API for multi-backend Keras.</p> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/distribution.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/guides/distribution.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>The Keras distribution API is a new interface designed to facilitate distributed deep learning across a variety of backends like JAX, TensorFlow and PyTorch. It provides tools for data and model parallelism, so that deep learning models can be scaled efficiently across multiple accelerators and hosts. Whether you run on GPUs or TPUs, the API provides a streamlined approach to initializing distributed environments, defining device meshes, and orchestrating the layout of tensors across computational resources.
Through classes like <code>DataParallel</code> and <code>ModelParallel</code>, it abstracts the complexity involved in parallel computation, making it easier for developers to accelerate their machine learning workflows.</p> <hr /> <h2 id="how-it-works">How it works</h2> <p>The Keras distribution API provides a global programming model that allows developers to compose applications that operate on tensors in a global context (as if working with a single device) while automatically managing distribution across many devices. The API leverages the underlying framework (e.g. JAX) to distribute the program and tensors according to the sharding directives through a procedure called single program, multiple data (SPMD) expansion.</p> <p>By decoupling the application from sharding directives, the API enables running the same application on a single device, multiple devices, or even multiple clients, while preserving its global semantics.</p> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span>
<span class="c1"># The distribution API is only implemented for the JAX backend for now.</span>
<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"jax"</span>
<span class="kn">import</span> <span class="nn">keras</span>
<span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span>
<span class="kn">import</span> <span class="nn">jax</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">tensorflow</span> <span class="kn">import</span> <span class="n">data</span> <span class="k">as</span> <span class="n">tf_data</span>  <span class="c1"># For dataset input.</span>
</code></pre></div> <hr /> <h2 id="devicemesh-and-tensorlayout"><code>DeviceMesh</code> and <code>TensorLayout</code></h2> <p>The <a href="/api/distribution/layout_map#devicemesh-class"><code>keras.distribution.DeviceMesh</code></a> class in the Keras distribution API represents a cluster of computational devices configured for distributed computation. It aligns with similar concepts in <a href="https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Mesh"><code>jax.sharding.Mesh</code></a> and <a href="https://www.tensorflow.org/api_docs/python/tf/experimental/dtensor/Mesh"><code>tf.experimental.dtensor.Mesh</code></a>, where a mesh is used to map the physical devices to a logical mesh structure.</p> <p>The <code>TensorLayout</code> class then specifies how tensors are distributed across the <code>DeviceMesh</code>: each entry in its axes names the <code>DeviceMesh</code> axis along which the corresponding tensor dimension is sharded, and <code>None</code> leaves that dimension unsharded.</p> <p>You can find more detailed explanations of these concepts in the <a href="https://www.tensorflow.org/guide/dtensor_overview#dtensors_model_of_distributed_tensors">TensorFlow DTensor guide</a>.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Retrieve the locally available GPU devices.</span>
<span class="n">devices</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">devices</span><span class="p">(</span><span class="s2">"gpu"</span><span class="p">)</span>  <span class="c1"># Assume the host has 8 local GPUs.</span>
<span class="c1"># Define a 2x4 device mesh with data and model parallel axes.</span>
<span class="n">mesh</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">distribution</span><span class="o">.</span><span class="n">DeviceMesh</span><span class="p">(</span>
    <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span
class="p">),</span> <span class="n">axis_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"data"</span><span class="p">,</span> <span class="s2">"model"</span><span class="p">],</span> <span class="n">devices</span><span class="o">=</span><span class="n">devices</span>
<span class="p">)</span>
<span class="c1"># A 2D layout, which describes how a tensor is distributed across the</span>
<span class="c1"># mesh. The layout can be visualized as a 2D grid with "model" as rows and</span>
<span class="c1"># "data" as columns, and it is a [4, 2] grid when it is mapped to the physical</span>
<span class="c1"># devices on the mesh.</span>
<span class="n">layout_2d</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">distribution</span><span class="o">.</span><span class="n">TensorLayout</span><span class="p">(</span><span class="n">axes</span><span class="o">=</span><span class="p">(</span><span class="s2">"model"</span><span class="p">,</span> <span class="s2">"data"</span><span class="p">),</span> <span class="n">device_mesh</span><span class="o">=</span><span class="n">mesh</span><span class="p">)</span>
<span class="c1"># A 4D layout that could be used for data parallelism on an image input.</span>
<span class="n">replicated_layout_4d</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">distribution</span><span class="o">.</span><span class="n">TensorLayout</span><span class="p">(</span>
    <span class="n">axes</span><span class="o">=</span><span class="p">(</span><span class="s2">"data"</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">),</span> <span class="n">device_mesh</span><span class="o">=</span><span class="n">mesh</span>
<span class="p">)</span>
</code></pre></div> <hr /> <h2 id="distribution">Distribution</h2> <p>The
<code>Distribution</code> class in Keras is the abstract base class for developing custom distribution strategies. It encapsulates the core logic needed to distribute a model's variables, input data, and intermediate computations across a device mesh. As an end user, you won't interact with this class directly, but rather with its subclasses such as <code>DataParallel</code> or <code>ModelParallel</code>.</p> <hr /> <h2 id="dataparallel">DataParallel</h2> <p>The <code>DataParallel</code> class in the Keras distribution API is designed for the data parallelism strategy in distributed training, where the model weights are replicated across all devices in the <code>DeviceMesh</code>, and each device processes a portion of the input data.</p> <p>Here is an example of how to use this class.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Create DataParallel with a list of devices.</span>
<span class="c1"># As a shortcut, the devices can be skipped,</span>
<span class="c1"># and Keras will detect all locally available devices.</span>
<span class="c1"># E.g.
# data_parallel = DataParallel()</span>
<span class="n">data_parallel</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">distribution</span><span class="o">.</span><span class="n">DataParallel</span><span class="p">(</span><span class="n">devices</span><span class="o">=</span><span class="n">devices</span><span class="p">)</span>
<span class="c1"># Or you can choose to create DataParallel with a 1D `DeviceMesh`.</span>
<span class="n">mesh_1d</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">distribution</span><span class="o">.</span><span class="n">DeviceMesh</span><span class="p">(</span>
    <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">8</span><span class="p">,),</span> <span class="n">axis_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"data"</span><span class="p">],</span> <span class="n">devices</span><span class="o">=</span><span class="n">devices</span>
<span class="p">)</span>
<span class="n">data_parallel</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">distribution</span><span class="o">.</span><span class="n">DataParallel</span><span class="p">(</span><span class="n">device_mesh</span><span class="o">=</span><span class="n">mesh_1d</span><span class="p">)</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">labels</span> <span class="o">=</span> <span class="n">np</span><span
class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
<span class="n">dataset</span> <span class="o">=</span> <span class="n">tf_data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">inputs</span><span class="p">,</span> <span class="n">labels</span><span class="p">))</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="mi">16</span><span class="p">)</span>
<span class="c1"># Set the global distribution.</span>
<span class="n">keras</span><span class="o">.</span><span class="n">distribution</span><span class="o">.</span><span class="n">set_distribution</span><span class="p">(</span><span class="n">data_parallel</span><span class="p">)</span>
<span class="c1"># Note that all the model weights from here on are replicated to</span>
<span class="c1"># all the devices of the `DeviceMesh`. This includes the RNG</span>
<span class="c1"># state, optimizer states, metrics, etc. The dataset fed into `model.fit` or</span>
<span class="c1"># `model.evaluate` will be split evenly on the batch dimension, and sent to</span>
<span class="c1"># all the devices.
# You don't have to do any manual aggregation of losses,</span>
<span class="c1"># since all the computation happens in a global context.</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Flatten</span><span class="p">()(</span><span class="n">inputs</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="mi">200</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">y</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.4</span><span class="p">)(</span><span class="n">y</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"softmax"</span><span class="p">)(</span><span class="n">y</span><span class="p">)</span>
<span class="n">model</span> <span
class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">y</span><span class="p">)</span>
<span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">loss</span><span class="o">=</span><span class="s2">"mse"</span><span class="p">)</span>
<span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
<span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">dataset</span><span class="p">)</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 8s 30ms/step - loss: 1.0116
Epoch 2/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.9237
Epoch 3/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - loss: 0.8736
8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 5ms/step - loss: 0.8349
0.842325747013092
</code></pre></div> </div> <hr /> <h2 id="modelparallel-and-layoutmap"><code>ModelParallel</code> and <code>LayoutMap</code></h2> <p><code>ModelParallel</code> is mostly useful when the model weights are too large to fit on a single accelerator. It allows you to split your model weights or activation tensors across all the devices on the <code>DeviceMesh</code>, which enables horizontal scaling for large models.</p> <p>Unlike the <code>DataParallel</code> strategy, where all weights are fully replicated, the weight layout under <code>ModelParallel</code> usually needs some customization for best performance.
We introduce <code>LayoutMap</code> to let you specify the <code>TensorLayout</code> for any weights and intermediate tensors from a global perspective.</p> <p><code>LayoutMap</code> is a dict-like object that maps strings to <code>TensorLayout</code> instances. It differs from a normal Python dict in that the string key is treated as a regex when retrieving a value. The class allows you to define a naming schema for <code>TensorLayout</code> instances and then retrieve the matching one. Typically, the key used to query is the <code>variable.path</code> attribute, which is the identifier of the variable. As a shortcut, a tuple or list of axis names is also allowed when inserting a value, and it will be converted to a <code>TensorLayout</code>.</p> <p>A <code>LayoutMap</code> can also optionally contain a <code>DeviceMesh</code>, which is used to populate <code>TensorLayout.device_mesh</code> if it is not set. When you retrieve a layout with a key that has no exact match, all existing keys in the layout map are treated as regexes and matched against the input key. If there are multiple matches, a <code>ValueError</code> is raised.
If no matches are found, <code>None</code> is returned.</p> <div class="codehilite"><pre><span></span><code><span class="n">mesh_2d</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">distribution</span><span class="o">.</span><span class="n">DeviceMesh</span><span class="p">(</span>
    <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span> <span class="n">axis_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"data"</span><span class="p">,</span> <span class="s2">"model"</span><span class="p">],</span> <span class="n">devices</span><span class="o">=</span><span class="n">devices</span>
<span class="p">)</span>
<span class="n">layout_map</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">distribution</span><span class="o">.</span><span class="n">LayoutMap</span><span class="p">(</span><span class="n">mesh_2d</span><span class="p">)</span>
<span class="c1"># The rules below mean that any weight matching d1/kernel will be sharded</span>
<span class="c1"># along the "model" axis (4 devices), and likewise for d1/bias.</span>
<span class="c1"># All other weights will be fully replicated.</span>
<span class="n">layout_map</span><span class="p">[</span><span class="s2">"d1/kernel"</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="s2">"model"</span><span class="p">)</span>
<span class="n">layout_map</span><span class="p">[</span><span class="s2">"d1/bias"</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="s2">"model"</span><span class="p">,)</span>
<span class="c1"># You can also set the layout for a layer output, like this:</span>
<span class="n">layout_map</span><span class="p">[</span><span
class="s2">"d2/output"</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="s2">"data"</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span>
<span class="c1"># Pass `layout_map` by keyword; recent Keras 3 releases make it keyword-only.</span>
<span class="n">model_parallel</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">distribution</span><span class="o">.</span><span class="n">ModelParallel</span><span class="p">(</span>
    <span class="n">layout_map</span><span class="o">=</span><span class="n">layout_map</span><span class="p">,</span> <span class="n">batch_dim_name</span><span class="o">=</span><span class="s2">"data"</span>
<span class="p">)</span>
<span class="n">keras</span><span class="o">.</span><span class="n">distribution</span><span class="o">.</span><span class="n">set_distribution</span><span class="p">(</span><span class="n">model_parallel</span><span class="p">)</span>
<span class="n">inputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Flatten</span><span class="p">()(</span><span class="n">inputs</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="mi">200</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span
class="s2">"d1"</span><span class="p">)(</span><span class="n">y</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.4</span><span class="p">)(</span><span class="n">y</span><span class="p">)</span>
<span class="n">y</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"softmax"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"d2"</span><span class="p">)(</span><span class="n">y</span><span class="p">)</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">y</span><span class="p">)</span>
<span class="c1"># The data will be sharded across the "data" dimension of the mesh, which</span>
<span class="c1"># has 2 devices.</span>
<span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">loss</span><span class="o">=</span><span class="s2">"mse"</span><span class="p">)</span>
<span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
<span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span
class="n">dataset</span><span class="p">)</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/3
/opt/conda/envs/keras-jax/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:761: UserWarning: Some donated buffers were not usable: ShapedArray(float32[784,50]).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
  warnings.warn("Some donated buffers were not usable:"
8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 8ms/step - loss: 1.0266
Epoch 2/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.9181
Epoch 3/3
8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - loss: 0.8725
8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - loss: 0.8381
0.8502610325813293
</code></pre></div> </div> <p>It is also easy to change the mesh structure to shift the balance between data parallelism and model parallelism. You do this by adjusting the shape of the mesh; no other code needs to change.</p> <div class="codehilite"><pre><span></span><code><span class="n">full_data_parallel_mesh</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">distribution</span><span class="o">.</span><span class="n">DeviceMesh</span><span class="p">(</span>
    <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">axis_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"data"</span><span class="p">,</span> <span class="s2">"model"</span><span class="p">],</span> <span class="n">devices</span><span class="o">=</span><span class="n">devices</span>
<span class="p">)</span>
<span class="n">more_data_parallel_mesh</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">distribution</span><span class="o">.</span><span class="n">DeviceMesh</span><span class="p">(</span>
    <span
class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">axis_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"data"</span><span class="p">,</span> <span class="s2">"model"</span><span class="p">],</span> <span class="n">devices</span><span class="o">=</span><span class="n">devices</span> <span class="p">)</span> <span class="n">more_model_parallel_mesh</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">distribution</span><span class="o">.</span><span class="n">DeviceMesh</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span> <span class="n">axis_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"data"</span><span class="p">,</span> <span class="s2">"model"</span><span class="p">],</span> <span class="n">devices</span><span class="o">=</span><span class="n">devices</span> <span class="p">)</span> <span class="n">full_model_parallel_mesh</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">distribution</span><span class="o">.</span><span class="n">DeviceMesh</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">8</span><span class="p">),</span> <span class="n">axis_names</span><span class="o">=</span><span class="p">[</span><span class="s2">"data"</span><span class="p">,</span> <span class="s2">"model"</span><span class="p">],</span> <span class="n">devices</span><span class="o">=</span><span class="n">devices</span> <span class="p">)</span> </code></pre></div> <h3 id="further-reading">Further reading</h3> <ol> <li><a 
href="https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html">JAX Distributed arrays and automatic parallelization</a></li> <li><a href="https://jax.readthedocs.io/en/latest/jax.sharding.html">JAX sharding module</a></li> <li><a href="https://www.tensorflow.org/tutorials/distribute/dtensor_ml_tutorial">TensorFlow Distributed training with DTensors</a></li> <li><a href="https://www.tensorflow.org/guide/dtensor_overview">TensorFlow DTensor concepts</a></li> <li><a href="https://www.tensorflow.org/tutorials/distribute/dtensor_keras_tutorial">Using DTensors with tf.keras</a></li> </ol> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#distributed-training-with-keras-3'>Distributed training with Keras 3</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#how-it-works'>How it works</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#devicemesh-and-tensorlayout'><code>DeviceMesh</code> and <code>TensorLayout</code></a> </div> <div class='k-outline-depth-2'> ◆ <a href='#distribution'>Distribution</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#dataparallel'>DataParallel</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#modelparallel-and-layoutmap'><code>ModelParallel</code> and <code>LayoutMap</code></a> </div> <div class='k-outline-depth-3'> <a href='#further-reading'>Further reading</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>