CINXE.COM
Multi-GPU distributed training with TensorFlow
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/guides/distributed_training_with_tensorflow/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Multi-GPU distributed training with TensorFlow"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Multi-GPU distributed training with TensorFlow"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Multi-GPU distributed training with TensorFlow</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link active" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-sublink" href="/guides/functional_api/">The Functional API</a> <a class="nav-sublink" href="/guides/sequential_model/">The Sequential model</a> <a class="nav-sublink" href="/guides/making_new_layers_and_models_via_subclassing/">Making new layers & models via subclassing</a> <a class="nav-sublink" href="/guides/training_with_built_in_methods/">Training & evaluation with the built-in methods</a> <a class="nav-sublink" href="/guides/custom_train_step_in_jax/">Customizing `fit()` with JAX</a> <a class="nav-sublink" href="/guides/custom_train_step_in_tensorflow/">Customizing `fit()` with TensorFlow</a> <a class="nav-sublink" href="/guides/custom_train_step_in_torch/">Customizing `fit()` with PyTorch</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_jax/">Writing a custom training loop in JAX</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_tensorflow/">Writing a custom training loop in TensorFlow</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_torch/">Writing a custom training loop in PyTorch</a> <a class="nav-sublink" href="/guides/serialization_and_saving/">Serialization & saving</a> <a class="nav-sublink" href="/guides/customizing_saving_and_serialization/">Customizing saving & serialization</a> <a class="nav-sublink" href="/guides/writing_your_own_callbacks/">Writing your own callbacks</a> <a class="nav-sublink" href="/guides/transfer_learning/">Transfer learning & fine-tuning</a> <a class="nav-sublink" href="/guides/distributed_training_with_jax/">Distributed training with JAX</a> <a class="nav-sublink active" href="/guides/distributed_training_with_tensorflow/">Distributed training with TensorFlow</a> <a class="nav-sublink" href="/guides/distributed_training_with_torch/">Distributed training with PyTorch</a> <a class="nav-sublink" href="/guides/distribution/">Distributed training with Keras 3</a> <a class="nav-sublink" href="/guides/migrating_to_keras_3/">Migrating Keras 2 code to Keras 3</a> <a class="nav-link" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/guides/'>Developer guides</a> / Multi-GPU distributed training with TensorFlow </div> <div class='k-content'> <h1 id="multigpu-distributed-training-with-tensorflow">Multi-GPU distributed training with TensorFlow</h1> <p><strong>Author:</strong> <a href="https://twitter.com/fchollet">fchollet</a><br> <strong>Date created:</strong> 2020/04/28<br> <strong>Last modified:</strong> 2023/06/29<br> <strong>Description:</strong> Guide to multi-GPU training for Keras models with TensorFlow.</p> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/distributed_training_with_tensorflow.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/guides/distributed_training_with_tensorflow.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>There are generally two ways to distribute computation across multiple devices:</p> <p><strong>Data parallelism</strong>, where a single model gets replicated on multiple devices or multiple machines. Each of them processes different batches of data, then they merge their results. There exist many variants of this setup, that differ in how the different model replicas merge results, in whether they stay in sync at every batch or whether they are more loosely coupled, etc.</p> <p><strong>Model parallelism</strong>, where different parts of a single model run on different devices, processing a single batch of data together. This works best with models that have a naturally-parallel architecture, such as models that feature multiple branches.</p> <p>This guide focuses on data parallelism, in particular <strong>synchronous data parallelism</strong>, where the different replicas of the model stay in sync after each batch they process. Synchronicity keeps the model convergence behavior identical to what you would see for single-device training.</p> <p>Specifically, this guide teaches you how to use the <a href="https://www.tensorflow.org/api_docs/python/tf/distribute"><code>tf.distribute</code></a> API to train Keras models on multiple GPUs, with minimal changes to your code, on multiple GPUs (typically 2 to 16) installed on a single machine (single host, multi-device training). This is the most common setup for researchers and small-scale industry workflows.</p> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"tensorflow"</span> <span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span> <span class="kn">import</span> <span class="nn">keras</span> </code></pre></div> <hr /> <h2 id="singlehost-multidevice-synchronous-training">Single-host, multi-device synchronous training</h2> <p>In this setup, you have one machine with several GPUs on it (typically 2 to 16). Each device will run a copy of your model (called a <strong>replica</strong>). For simplicity, in what follows, we'll assume we're dealing with 8 GPUs, at no loss of generality.</p> <p><strong>How it works</strong></p> <p>At each step of training:</p> <ul> <li>The current batch of data (called <strong>global batch</strong>) is split into 8 different sub-batches (called <strong>local batches</strong>). For instance, if the global batch has 512 samples, each of the 8 local batches will have 64 samples.</li> <li>Each of the 8 replicas independently processes a local batch: they run a forward pass, then a backward pass, outputting the gradient of the weights with respect to the loss of the model on the local batch.</li> <li>The weight updates originating from local gradients are efficiently merged across the 8 replicas. Because this is done at the end of every step, the replicas always stay in sync.</li> </ul> <p>In practice, the process of synchronously updating the weights of the model replicas is handled at the level of each individual weight variable. This is done through a <strong>mirrored variable</strong> object.</p> <p><strong>How to use it</strong></p> <p>To do single-host, multi-device synchronous training with a Keras model, you would use the <a href="https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy"><code>tf.distribute.MirroredStrategy</code> API</a>. Here's how it works:</p> <ul> <li>Instantiate a <code>MirroredStrategy</code>, optionally configuring which specific devices you want to use (by default the strategy will use all GPUs available).</li> <li>Use the strategy object to open a scope, and within this scope, create all the Keras objects you need that contain variables. Typically, that means <strong>creating & compiling the model</strong> inside the distribution scope. In some cases, the first call to <code>fit()</code> may also create variables, so it's a good idea to put your <code>fit()</code> call in the scope as well.</li> <li>Train the model via <code>fit()</code> as usual.</li> </ul> <p>Importantly, we recommend that you use <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><code>tf.data.Dataset</code></a> objects to load data in a multi-device or distributed workflow.</p> <p>Schematically, it looks like this:</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Create a MirroredStrategy.</span> <span class="n">strategy</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">distribute</span><span class="o">.</span><span class="n">MirroredStrategy</span><span class="p">()</span> <span class="nb">print</span><span class="p">(</span><span class="s1">'Number of devices: </span><span class="si">{}</span><span class="s1">'</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">strategy</span><span class="o">.</span><span class="n">num_replicas_in_sync</span><span class="p">))</span> <span class="c1"># Open a strategy scope.</span> <span class="k">with</span> <span class="n">strategy</span><span class="o">.</span><span class="n">scope</span><span class="p">():</span> <span class="c1"># Everything that creates variables should be under the strategy scope.</span> <span class="c1"># In general this is only model construction & `compile()`.</span> <span class="n">model</span> <span class="o">=</span> <span class="n">Model</span><span class="p">(</span><span class="o">...</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="o">...</span><span class="p">)</span> <span class="c1"># Train the model on all available devices.</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">val_dataset</span><span class="p">,</span> <span class="o">...</span><span class="p">)</span> <span class="c1"># Test the model on all available devices.</span> <span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">test_dataset</span><span class="p">)</span> </code></pre></div> <p>Here's a simple end-to-end runnable example:</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">get_compiled_model</span><span class="p">():</span> <span class="c1"># Make a simple 2-layer densely-connected neural network.</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">784</span><span class="p">,))</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(),</span> <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">SparseCategoricalAccuracy</span><span class="p">()],</span> <span class="p">)</span> <span class="k">return</span> <span class="n">model</span> <span class="k">def</span> <span class="nf">get_dataset</span><span class="p">():</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">32</span> <span class="n">num_val_samples</span> <span class="o">=</span> <span class="mi">10000</span> <span class="c1"># Return the MNIST dataset in the form of a [`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset).</span> <span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">mnist</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span> <span class="c1"># Preprocess the data (these are Numpy arrays)</span> <span class="n">x_train</span> <span class="o">=</span> <span class="n">x_train</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">784</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">)</span> <span class="o">/</span> <span class="mi">255</span> <span class="n">x_test</span> <span class="o">=</span> <span class="n">x_test</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">784</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">)</span> <span class="o">/</span> <span class="mi">255</span> <span class="n">y_train</span> <span class="o">=</span> <span class="n">y_train</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">)</span> <span class="n">y_test</span> <span class="o">=</span> <span class="n">y_test</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">)</span> <span class="c1"># Reserve num_val_samples samples for validation</span> <span class="n">x_val</span> <span class="o">=</span> <span class="n">x_train</span><span class="p">[</span><span class="o">-</span><span class="n">num_val_samples</span><span class="p">:]</span> <span class="n">y_val</span> <span class="o">=</span> <span class="n">y_train</span><span class="p">[</span><span class="o">-</span><span class="n">num_val_samples</span><span class="p">:]</span> <span class="n">x_train</span> <span class="o">=</span> <span class="n">x_train</span><span class="p">[:</span><span class="o">-</span><span class="n">num_val_samples</span><span class="p">]</span> <span class="n">y_train</span> <span class="o">=</span> <span class="n">y_train</span><span class="p">[:</span><span class="o">-</span><span class="n">num_val_samples</span><span class="p">]</span> <span class="k">return</span> <span class="p">(</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">))</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">batch_size</span><span class="p">),</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">x_val</span><span class="p">,</span> <span class="n">y_val</span><span class="p">))</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">batch_size</span><span class="p">),</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">))</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">batch_size</span><span class="p">),</span> <span class="p">)</span> <span class="c1"># Create a MirroredStrategy.</span> <span class="n">strategy</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">distribute</span><span class="o">.</span><span class="n">MirroredStrategy</span><span class="p">()</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Number of devices: </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">strategy</span><span class="o">.</span><span class="n">num_replicas_in_sync</span><span class="p">))</span> <span class="c1"># Open a strategy scope.</span> <span class="k">with</span> <span class="n">strategy</span><span class="o">.</span><span class="n">scope</span><span class="p">():</span> <span class="c1"># Everything that creates variables should be under the strategy scope.</span> <span class="c1"># In general this is only model construction & `compile()`.</span> <span class="n">model</span> <span class="o">=</span> <span class="n">get_compiled_model</span><span class="p">()</span> <span class="c1"># Train the model on all available devices.</span> <span class="n">train_dataset</span><span class="p">,</span> <span class="n">val_dataset</span><span class="p">,</span> <span class="n">test_dataset</span> <span class="o">=</span> <span class="n">get_dataset</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">val_dataset</span><span class="p">)</span> <span class="c1"># Test the model on all available devices.</span> <span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">test_dataset</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',) Number of devices: 1 Epoch 1/2 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - loss: 0.3830 - sparse_categorical_accuracy: 0.8884 - val_loss: 0.1361 - val_sparse_categorical_accuracy: 0.9574 Epoch 2/2 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 9s 3ms/step - loss: 0.1068 - sparse_categorical_accuracy: 0.9671 - val_loss: 0.0894 - val_sparse_categorical_accuracy: 0.9724 313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0988 - sparse_categorical_accuracy: 0.9673 </code></pre></div> </div> <hr /> <h2 id="using-callbacks-to-ensure-fault-tolerance">Using callbacks to ensure fault tolerance</h2> <p>When using distributed training, you should always make sure you have a strategy to recover from failure (fault tolerance). The simplest way to handle this is to pass <code>ModelCheckpoint</code> callback to <code>fit()</code>, to save your model at regular intervals (e.g. every 100 batches or every epoch). You can then restart training from your saved model.</p> <p>Here's a simple example:</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Prepare a directory to store all the checkpoints.</span> <span class="n">checkpoint_dir</span> <span class="o">=</span> <span class="s2">"./ckpt"</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">):</span> <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">)</span> <span class="k">def</span> <span class="nf">make_or_restore_model</span><span class="p">():</span> <span class="c1"># Either restore the latest model, or create a fresh one</span> <span class="c1"># if there is no checkpoint available.</span> <span class="n">checkpoints</span> <span class="o">=</span> <span class="p">[</span><span class="n">checkpoint_dir</span> <span class="o">+</span> <span class="s2">"/"</span> <span class="o">+</span> <span class="n">name</span> <span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">os</span><span class="o">.</span><span class="n">listdir</span><span class="p">(</span><span class="n">checkpoint_dir</span><span class="p">)]</span> <span class="k">if</span> <span class="n">checkpoints</span><span class="p">:</span> <span class="n">latest_checkpoint</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">checkpoints</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">getctime</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Restoring from"</span><span class="p">,</span> <span class="n">latest_checkpoint</span><span class="p">)</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">load_model</span><span class="p">(</span><span class="n">latest_checkpoint</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Creating a new model"</span><span class="p">)</span> <span class="k">return</span> <span class="n">get_compiled_model</span><span class="p">()</span> <span class="k">def</span> <span class="nf">run_training</span><span class="p">(</span><span class="n">epochs</span><span class="o">=</span><span class="mi">1</span><span class="p">):</span> <span class="c1"># Create a MirroredStrategy.</span> <span class="n">strategy</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">distribute</span><span class="o">.</span><span class="n">MirroredStrategy</span><span class="p">()</span> <span class="c1"># Open a strategy scope and create/restore the model</span> <span class="k">with</span> <span class="n">strategy</span><span class="o">.</span><span class="n">scope</span><span class="p">():</span> <span class="n">model</span> <span class="o">=</span> <span class="n">make_or_restore_model</span><span class="p">()</span> <span class="n">callbacks</span> <span class="o">=</span> <span class="p">[</span> <span class="c1"># This callback saves a SavedModel every epoch</span> <span class="c1"># We include the current epoch in the folder name.</span> <span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">ModelCheckpoint</span><span class="p">(</span> <span class="n">filepath</span><span class="o">=</span><span class="n">checkpoint_dir</span> <span class="o">+</span> <span class="s2">"/ckpt-</span><span class="si">{epoch}</span><span class="s2">.keras"</span><span class="p">,</span> <span class="n">save_freq</span><span class="o">=</span><span class="s2">"epoch"</span><span class="p">,</span> <span class="p">)</span> <span class="p">]</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">train_dataset</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">epochs</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="n">callbacks</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">val_dataset</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Running the first time creates the model</span> <span class="n">run_training</span><span class="p">(</span><span class="n">epochs</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># Calling the same function again will resume from where we left off</span> <span class="n">run_training</span><span class="p">(</span><span class="n">epochs</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',) Creating a new model 1563/1563 - 7s - 4ms/step - loss: 0.2275 - sparse_categorical_accuracy: 0.9320 - val_loss: 0.1373 - val_sparse_categorical_accuracy: 0.9571 INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',) Restoring from ./ckpt/ckpt-1.keras 1563/1563 - 6s - 4ms/step - loss: 0.0944 - sparse_categorical_accuracy: 0.9717 - val_loss: 0.0972 - val_sparse_categorical_accuracy: 0.9710 </code></pre></div> </div> <hr /> <h2 id="tfdata"><a href="https://www.tensorflow.org/api_docs/python/tf/data"><code>tf.data</code></a> performance tips</h2> <p>When doing distributed training, the efficiency with which you load data can often become critical. Here are a few tips to make sure your <a href="https://www.tensorflow.org/api_docs/python/tf/data"><code>tf.data</code></a> pipelines run as fast as possible.</p> <p><strong>Note about dataset batching</strong></p> <p>When creating your dataset, make sure it is batched with the global batch size. For instance, if each of your 8 GPUs is capable of running a batch of 64 samples, you call use a global batch size of 512.</p> <p><strong>Calling <code>dataset.cache()</code></strong></p> <p>If you call <code>.cache()</code> on a dataset, its data will be cached after running through the first iteration over the data. Every subsequent iteration will use the cached data. The cache can be in memory (default) or to a local file you specify.</p> <p>This can improve performance when:</p> <ul> <li>Your data is not expected to change from iteration to iteration</li> <li>You are reading data from a remote distributed filesystem</li> <li>You are reading data from local disk, but your data would fit in memory and your workflow is significantly IO-bound (e.g. reading & decoding image files).</li> </ul> <p><strong>Calling <code>dataset.prefetch(buffer_size)</code></strong></p> <p>You should almost always call <code>.prefetch(buffer_size)</code> after creating a dataset. It means your data pipeline will run asynchronously from your model, with new samples being preprocessed and stored in a buffer while the current batch samples are used to train the model. The next batch will be prefetched in GPU memory by the time the current batch is over.</p> <p>That's it!</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#multigpu-distributed-training-with-tensorflow'>Multi-GPU distributed training with TensorFlow</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#singlehost-multidevice-synchronous-training'>Single-host, multi-device synchronous training</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#using-callbacks-to-ensure-fault-tolerance'>Using callbacks to ensure fault tolerance</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#tfdata-performance-tips'><code>tf.data</code> performance tips</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>