Multi-GPU distributed training with JAX
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/guides/distributed_training_with_jax/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Multi-GPU distributed training with JAX"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Multi-GPU distributed training with JAX"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Multi-GPU distributed training with JAX</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 
'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link active" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-sublink" href="/guides/functional_api/">The Functional API</a> <a class="nav-sublink" href="/guides/sequential_model/">The Sequential model</a> <a class="nav-sublink" href="/guides/making_new_layers_and_models_via_subclassing/">Making new layers & models via subclassing</a> <a class="nav-sublink" href="/guides/training_with_built_in_methods/">Training & evaluation with the built-in methods</a> <a class="nav-sublink" href="/guides/custom_train_step_in_jax/">Customizing `fit()` with JAX</a> <a class="nav-sublink" href="/guides/custom_train_step_in_tensorflow/">Customizing `fit()` with TensorFlow</a> <a class="nav-sublink" href="/guides/custom_train_step_in_torch/">Customizing `fit()` with PyTorch</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_jax/">Writing a custom training loop in JAX</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_tensorflow/">Writing a custom training loop in TensorFlow</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_torch/">Writing a custom training loop in PyTorch</a> <a class="nav-sublink" 
href="/guides/serialization_and_saving/">Serialization & saving</a> <a class="nav-sublink" href="/guides/customizing_saving_and_serialization/">Customizing saving & serialization</a> <a class="nav-sublink" href="/guides/writing_your_own_callbacks/">Writing your own callbacks</a> <a class="nav-sublink" href="/guides/transfer_learning/">Transfer learning & fine-tuning</a> <a class="nav-sublink active" href="/guides/distributed_training_with_jax/">Distributed training with JAX</a> <a class="nav-sublink" href="/guides/distributed_training_with_tensorflow/">Distributed training with TensorFlow</a> <a class="nav-sublink" href="/guides/distributed_training_with_torch/">Distributed training with PyTorch</a> <a class="nav-sublink" href="/guides/distribution/">Distributed training with Keras 3</a> <a class="nav-sublink" href="/guides/migrating_to_keras_3/">Migrating Keras 2 code to Keras 3</a> <a class="nav-sublink" href="/guides/keras_tuner/">Hyperparameter Tuning</a> <a class="nav-sublink" href="/guides/keras_cv/">KerasCV</a> <a class="nav-sublink" href="/guides/keras_nlp/">KerasNLP</a> <a class="nav-sublink" href="/guides/keras_hub/">KerasHub</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if 
(e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." 
autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + encodeURIComponent(query); return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-content'> <h1 id="multigpu-distributed-training-with-jax">Multi-GPU distributed training with JAX</h1> <p><strong>Author:</strong> <a href="https://twitter.com/fchollet">fchollet</a><br> <strong>Date created:</strong> 2023/07/11<br> <strong>Last modified:</strong> 2023/07/11<br> <strong>Description:</strong> Guide to multi-GPU/TPU training for Keras models with JAX.</p> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a 
href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/distributed_training_with_jax.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/guides/distributed_training_with_jax.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>There are generally two ways to distribute computation across multiple devices:</p> <p><strong>Data parallelism</strong>, where a single model is replicated on multiple devices or multiple machines. Each replica processes a different batch of data, and the replicas then merge their results. There are many variants of this setup, which differ in how the replicas merge their results, whether they stay in sync at every batch or are more loosely coupled, and so on.</p> <p><strong>Model parallelism</strong>, where different parts of a single model run on different devices, processing a single batch of data together. This works best with models that have a naturally parallel architecture, such as models that feature multiple branches.</p> <p>This guide focuses on data parallelism, in particular <strong>synchronous data parallelism</strong>, where the different replicas of the model stay in sync after each batch they process. Synchronicity keeps the convergence behavior of the model identical to what you would see in single-device training with the same global batch size.</p> <p>Specifically, this guide teaches you how to use <code>jax.sharding</code> APIs to train Keras models, with minimal changes to your code, on multiple GPUs or TPUs (typically 2 to 16) installed on a single machine (single host, multi-device training). 
This is the most common setup for researchers and small-scale industry workflows.</p> <hr /> <h2 id="setup">Setup</h2> <p>Let's start by defining the function that creates the model that we will train, and the function that creates the dataset we will train on (MNIST in this case).</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span>

<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"jax"</span>

<span class="kn">import</span> <span class="nn">jax</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span>
<span class="kn">import</span> <span class="nn">keras</span>

<span class="kn">from</span> <span class="nn">jax.experimental</span> <span class="kn">import</span> <span class="n">mesh_utils</span>
<span class="kn">from</span> <span class="nn">jax.sharding</span> <span class="kn">import</span> <span class="n">Mesh</span>
<span class="kn">from</span> <span class="nn">jax.sharding</span> <span class="kn">import</span> <span class="n">NamedSharding</span>
<span class="kn">from</span> <span class="nn">jax.sharding</span> <span class="kn">import</span> <span class="n">PartitionSpec</span> <span class="k">as</span> <span class="n">P</span>


<span class="k">def</span> <span class="nf">get_model</span><span class="p">():</span>
    <span class="c1"># Make a simple convnet with batch normalization and dropout.</span>
    <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Rescaling</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">/</span> <span class="mf">255.0</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">12</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)(</span>
        <span class="n">x</span>
    <span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">(</span><span class="n">scale</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">center</span><span class="o">=</span><span class="kc">True</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span>
        <span class="n">filters</span><span class="o">=</span><span class="mi">24</span><span class="p">,</span>
        <span class="n">kernel_size</span><span class="o">=</span><span class="mi">6</span><span class="p">,</span>
        <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
        <span class="n">strides</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span>
    <span class="p">)(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">(</span><span class="n">scale</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">center</span><span class="o">=</span><span class="kc">True</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span>
        <span class="n">filters</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span>
        <span class="n">kernel_size</span><span class="o">=</span><span class="mi">6</span><span class="p">,</span>
        <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">,</span>
        <span class="n">strides</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span>
        <span class="n">name</span><span class="o">=</span><span class="s2">"large_k"</span><span class="p">,</span>
    <span class="p">)(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">(</span><span class="n">scale</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">center</span><span class="o">=</span><span class="kc">True</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">GlobalAveragePooling2D</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">outputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">model</span>


<span class="k">def</span> <span class="nf">get_datasets</span><span class="p">():</span>
    <span class="c1"># Load the data and split it between train and test sets</span>
    <span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">mnist</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span>

    <span class="c1"># Cast pixels to float32 (the Rescaling layer in the model scales them to [0, 1])</span>
    <span class="n">x_train</span> <span class="o">=</span> <span class="n">x_train</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">)</span>
    <span class="n">x_test</span> <span class="o">=</span> <span class="n">x_test</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">)</span>
    <span class="c1"># Make sure images have shape (28, 28, 1)</span>
    <span class="n">x_train</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    <span class="n">x_test</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span>
    <span class="nb">print</span><span class="p">(</span><span class="s2">"x_train shape:"</span><span class="p">,</span> <span class="n">x_train</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span>
    <span class="nb">print</span><span class="p">(</span><span class="n">x_train</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="s2">"train samples"</span><span class="p">)</span>
    <span class="nb">print</span><span class="p">(</span><span class="n">x_test</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="s2">"test samples"</span><span class="p">)</span>

    <span class="c1"># Create TF Datasets</span>
    <span class="n">train_data</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">))</span>
    <span class="n">eval_data</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">train_data</span><span class="p">,</span> <span class="n">eval_data</span>
</code></pre></div> <hr /> <h2 id="singlehost-multidevice-synchronous-training">Single-host, multi-device synchronous training</h2> <p>In this setup, you have one machine with several GPUs or TPUs on it (typically 2 to 16). Each device will run a copy of your model (called a <strong>replica</strong>). For simplicity, in what follows, we'll assume we're dealing with 8 GPUs, without loss of generality.</p> <p><strong>How it works</strong></p> <p>At each step of training:</p> <ul> <li>The current batch of data (called the <strong>global batch</strong>) is split into 8 different sub-batches (called <strong>local batches</strong>). For instance, if the global batch has 512 samples, each of the 8 local batches will have 64 samples.</li> <li>Each of the 8 replicas independently processes a local batch: it runs a forward pass, then a backward pass, producing the gradient of the model's loss on that local batch with respect to the weights.</li> <li>The local gradients are efficiently merged across the 8 replicas into a single weight update, which every replica applies. Because this is done at the end of every step, the replicas always stay in sync.</li> </ul> <p>In practice, the process of synchronously updating the weights of the model replicas is handled at the level of each individual weight variable. This is done by using a <code>jax.sharding.NamedSharding</code> that is configured to replicate the variables. Because the training step is compiled with <code>jax.jit</code>, the compiler inserts the cross-device communication needed to merge the gradients; you do not write it by hand.</p> <p><strong>How to use it</strong></p> <p>To do single-host, multi-device synchronous training with a Keras model, you use the <code>jax.sharding</code> features. 
Here's how it works:</p> <ul> <li>We first create a device mesh using <code>mesh_utils.create_device_mesh</code>.</li> <li>We use <code>jax.sharding.Mesh</code>, <code>jax.sharding.NamedSharding</code>, and <code>jax.sharding.PartitionSpec</code> to define how to partition JAX arrays: <ul> <li>We specify that we want to replicate the model and optimizer variables across all devices by using a spec with no axis.</li> <li>We specify that we want to shard the data across devices by using a spec that splits along the batch dimension.</li> </ul> </li> <li>We use <code>jax.device_put</code> to replicate the model and optimizer variables across devices. This happens once, at the beginning.</li> <li>In the training loop, for each batch that we process, we use <code>jax.device_put</code> to split the batch across devices before invoking the train step.</li> </ul> <p>Here's the flow, where each step is split into its own utility function:</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Config</span>
<span class="n">num_epochs</span> <span class="o">=</span> <span class="mi">2</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">64</span>

<span class="n">train_data</span><span class="p">,</span> <span class="n">eval_data</span> <span class="o">=</span> <span class="n">get_datasets</span><span class="p">()</span>
<span class="n">train_data</span> <span class="o">=</span> <span class="n">train_data</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">drop_remainder</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>

<span class="n">model</span> <span class="o">=</span> <span class="n">get_model</span><span class="p">()</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="mf">1e-3</span><span class="p">)</span>
<span class="n">loss</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>

<span class="c1"># Initialize all state with .build()</span>
<span class="p">(</span><span class="n">one_batch</span><span class="p">,</span> <span class="n">one_batch_labels</span><span class="p">)</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">train_data</span><span class="p">))</span>
<span class="n">model</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">one_batch</span><span class="p">)</span>
<span class="n">optimizer</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">trainable_variables</span><span class="p">)</span>


<span class="c1"># This is the loss function that will be differentiated.</span>
<span class="c1"># Keras provides a pure functional forward pass: model.stateless_call</span>
<span class="k">def</span> <span class="nf">compute_loss</span><span class="p">(</span><span class="n">trainable_variables</span><span class="p">,</span> <span class="n">non_trainable_variables</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
    <span class="n">y_pred</span><span class="p">,</span> <span class="n">updated_non_trainable_variables</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">stateless_call</span><span class="p">(</span>
        <span class="n">trainable_variables</span><span class="p">,</span> <span class="n">non_trainable_variables</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span>
    <span class="p">)</span>
    <span class="n">loss_value</span> <span class="o">=</span> <span class="n">loss</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">loss_value</span><span class="p">,</span> <span class="n">updated_non_trainable_variables</span>


<span class="c1"># Function to compute gradients</span>
<span class="n">compute_gradients</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">value_and_grad</span><span class="p">(</span><span class="n">compute_loss</span><span class="p">,</span> <span class="n">has_aux</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>


<span class="c1"># Training step. Keras provides a pure functional optimizer.stateless_apply</span>
<span class="nd">@jax</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="n">train_state</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span>
    <span class="n">trainable_variables</span><span class="p">,</span> <span class="n">non_trainable_variables</span><span class="p">,</span> <span class="n">optimizer_variables</span> <span class="o">=</span> <span class="n">train_state</span>
    <span class="p">(</span><span class="n">loss_value</span><span class="p">,</span> <span class="n">non_trainable_variables</span><span class="p">),</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">compute_gradients</span><span class="p">(</span>
        <span class="n">trainable_variables</span><span class="p">,</span> <span class="n">non_trainable_variables</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span>
    <span class="p">)</span>

    <span class="n">trainable_variables</span><span class="p">,</span> <span class="n">optimizer_variables</span> <span class="o">=</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">stateless_apply</span><span class="p">(</span>
        <span class="n">optimizer_variables</span><span class="p">,</span> <span class="n">grads</span><span class="p">,</span> <span class="n">trainable_variables</span>
    <span class="p">)</span>

    <span class="k">return</span> <span class="n">loss_value</span><span class="p">,</span> <span class="p">(</span>
        <span class="n">trainable_variables</span><span class="p">,</span>
        <span class="n">non_trainable_variables</span><span class="p">,</span>
        <span class="n">optimizer_variables</span><span class="p">,</span>
    <span class="p">)</span>


<span class="c1"># Replicate the model and optimizer variables on all devices</span>
<span class="k">def</span> <span class="nf">get_replicated_train_state</span><span class="p">(</span><span class="n">devices</span><span class="p">):</span>
    <span class="c1"># All variables will be replicated on all devices</span>
    <span class="n">var_mesh</span> <span class="o">=</span> <span class="n">Mesh</span><span class="p">(</span><span class="n">devices</span><span class="p">,</span> <span class="n">axis_names</span><span class="o">=</span><span class="p">(</span><span class="s2">"_"</span><span class="p">,))</span>
    <span class="c1"># In NamedSharding, axes not mentioned are replicated (all axes here)</span>
    <span class="n">var_replication</span> <span class="o">=</span> <span class="n">NamedSharding</span><span class="p">(</span><span class="n">var_mesh</span><span class="p">,</span> <span class="n">P</span><span class="p">())</span>

    <span class="c1"># Apply the distribution settings to the model variables</span>
    <span class="n">trainable_variables</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">device_put</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">trainable_variables</span><span class="p">,</span> <span class="n">var_replication</span><span class="p">)</span>
    <span class="n">non_trainable_variables</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">device_put</span><span class="p">(</span>
        <span class="n">model</span><span class="o">.</span><span class="n">non_trainable_variables</span><span class="p">,</span> <span class="n">var_replication</span>
    <span class="p">)</span>
    <span class="n">optimizer_variables</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">device_put</span><span class="p">(</span><span class="n">optimizer</span><span class="o">.</span><span class="n">variables</span><span class="p">,</span> <span class="n">var_replication</span><span class="p">)</span>

    <span class="c1"># Combine all state in a tuple</span>
    <span class="k">return</span> <span class="p">(</span><span class="n">trainable_variables</span><span class="p">,</span> <span class="n">non_trainable_variables</span><span class="p">,</span> <span class="n">optimizer_variables</span><span class="p">)</span>


<span class="n">num_devices</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">local_devices</span><span class="p">())</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Running on </span><span class="si">{</span><span class="n">num_devices</span><span class="si">}</span><span class="s2"> devices: </span><span class="si">{</span><span class="n">jax</span><span class="o">.</span><span class="n">local_devices</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
<span class="n">devices</span> <span class="o">=</span> <span class="n">mesh_utils</span><span class="o">.</span><span class="n">create_device_mesh</span><span class="p">((</span><span class="n">num_devices</span><span class="p">,))</span>

<span class="c1"># Data will be split along the batch axis</span>
<span class="n">data_mesh</span> <span class="o">=</span> <span class="n">Mesh</span><span class="p">(</span><span class="n">devices</span><span class="p">,</span> <span class="n">axis_names</span><span class="o">=</span><span class="p">(</span><span class="s2">"batch"</span><span class="p">,))</span>  <span class="c1"># naming axes of the mesh</span>
<span class="n">data_sharding</span> <span class="o">=</span> <span class="n">NamedSharding</span><span class="p">(</span>
    <span class="n">data_mesh</span><span class="p">,</span>
    <span class="n">P</span><span class="p">(</span>
        <span class="s2">"batch"</span><span class="p">,</span>
    <span class="p">),</span>
<span class="p">)</span>  <span class="c1"># naming axes of the sharded partition</span>

<span class="c1"># Display data sharding</span>
<span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">train_data</span><span class="p">))</span>
<span class="n">sharded_x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">device_put</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">data_sharding</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">"Data sharding"</span><span class="p">)</span>
<span class="n">jax</span><span class="o">.</span><span class="n">debug</span><span class="o">.</span><span 
class="n">visualize_array_sharding</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">numpy</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">sharded_x</span><span class="p">,</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">28</span> <span class="o">*</span> <span class="mi">28</span><span class="p">]))</span> <span class="n">train_state</span> <span class="o">=</span> <span class="n">get_replicated_train_state</span><span class="p">(</span><span class="n">devices</span><span class="p">)</span> <span class="c1"># Custom training loop</span> <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_epochs</span><span class="p">):</span> <span class="n">data_iter</span> <span class="o">=</span> <span class="nb">iter</span><span class="p">(</span><span class="n">train_data</span><span class="p">)</span> <span class="k">for</span> <span class="n">data</span> <span class="ow">in</span> <span class="n">data_iter</span><span class="p">:</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">data</span> <span class="n">sharded_x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">device_put</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">data_sharding</span><span class="p">)</span> <span class="n">loss_value</span><span class="p">,</span> <span class="n">train_state</span> <span class="o">=</span> <span class="n">train_step</span><span class="p">(</span><span class="n">train_state</span><span class="p">,</span> <span class="n">sharded_x</span><span class="p">,</span> <span class="n">y</span><span 
class="o">.</span><span class="n">numpy</span><span class="p">())</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Epoch"</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="s2">"loss:"</span><span class="p">,</span> <span class="n">loss_value</span><span class="p">)</span> <span class="c1"># Post-processing model state update to write them back into the model</span> <span class="n">trainable_variables</span><span class="p">,</span> <span class="n">non_trainable_variables</span><span class="p">,</span> <span class="n">optimizer_variables</span> <span class="o">=</span> <span class="n">train_state</span> <span class="k">for</span> <span class="n">variable</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">trainable_variables</span><span class="p">,</span> <span class="n">trainable_variables</span><span class="p">):</span> <span class="n">variable</span><span class="o">.</span><span class="n">assign</span><span class="p">(</span><span class="n">value</span><span class="p">)</span> <span class="k">for</span> <span class="n">variable</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">non_trainable_variables</span><span class="p">,</span> <span class="n">non_trainable_variables</span><span class="p">):</span> <span class="n">variable</span><span class="o">.</span><span class="n">assign</span><span class="p">(</span><span class="n">value</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>x_train shape: (60000, 28, 28, 1) 60000 train samples 10000 test samples Running on 1 devices: [CpuDevice(id=0)] Data sharding 
</code></pre></div> </div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> CPU 0 </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> </pre> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 0 loss: 0.28599858
Epoch 1 loss: 0.23666474
</code></pre></div> </div> <p>That's it!</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#multigpu-distributed-training-with-jax'>Multi-GPU distributed training with JAX</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#singlehost-multidevice-synchronous-training'>Single-host, multi-device synchronous training</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> 
<a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>