Multi-GPU distributed training with PyTorch
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/guides/distributed_training_with_torch/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Multi-GPU distributed training with PyTorch"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Multi-GPU distributed training with PyTorch"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Multi-GPU distributed training with PyTorch</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); 
ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link active" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-sublink" href="/guides/functional_api/">The Functional API</a> <a class="nav-sublink" href="/guides/sequential_model/">The Sequential model</a> <a class="nav-sublink" href="/guides/making_new_layers_and_models_via_subclassing/">Making new layers & models via subclassing</a> <a class="nav-sublink" href="/guides/training_with_built_in_methods/">Training & evaluation with the built-in methods</a> <a class="nav-sublink" href="/guides/custom_train_step_in_jax/">Customizing `fit()` with JAX</a> <a class="nav-sublink" href="/guides/custom_train_step_in_tensorflow/">Customizing `fit()` with TensorFlow</a> <a class="nav-sublink" href="/guides/custom_train_step_in_torch/">Customizing `fit()` with PyTorch</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_jax/">Writing a custom training loop in JAX</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_tensorflow/">Writing a custom training loop in TensorFlow</a> <a class="nav-sublink" href="/guides/writing_a_custom_training_loop_in_torch/">Writing a custom training loop in PyTorch</a> <a class="nav-sublink" 
href="/guides/serialization_and_saving/">Serialization & saving</a> <a class="nav-sublink" href="/guides/customizing_saving_and_serialization/">Customizing saving & serialization</a> <a class="nav-sublink" href="/guides/writing_your_own_callbacks/">Writing your own callbacks</a> <a class="nav-sublink" href="/guides/transfer_learning/">Transfer learning & fine-tuning</a> <a class="nav-sublink" href="/guides/distributed_training_with_jax/">Distributed training with JAX</a> <a class="nav-sublink" href="/guides/distributed_training_with_tensorflow/">Distributed training with TensorFlow</a> <a class="nav-sublink active" href="/guides/distributed_training_with_torch/">Distributed training with PyTorch</a> <a class="nav-sublink" href="/guides/distribution/">Distributed training with Keras 3</a> <a class="nav-sublink" href="/guides/migrating_to_keras_3/">Migrating Keras 2 code to Keras 3</a> <a class="nav-sublink" href="/guides/keras_tuner/">Hyperparameter Tuning</a> <a class="nav-sublink" href="/guides/keras_cv/">KerasCV</a> <a class="nav-sublink" href="/guides/keras_nlp/">KerasNLP</a> <a class="nav-sublink" href="/guides/keras_hub/">KerasHub</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if 
(e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." 
autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/guides/'>Developer guides</a> / Multi-GPU distributed training with PyTorch </div> <div class='k-content'> <h1 id="multigpu-distributed-training-with-pytorch">Multi-GPU distributed training with PyTorch</h1> <p><strong>Author:</strong> <a href="https://twitter.com/fchollet">fchollet</a><br> <strong>Date created:</strong> 2023/06/29<br> <strong>Last modified:</strong> 2023/06/29<br> <strong>Description:</strong> Guide to multi-GPU training for Keras models with PyTorch.</p> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a 
href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/distributed_training_with_torch.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/guides/distributed_training_with_torch.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>There are generally two ways to distribute computation across multiple devices:</p> <p><strong>Data parallelism</strong>, where a single model gets replicated on multiple devices or multiple machines. Each replica processes different batches of data, and the replicas then merge their results. Many variants of this setup exist; they differ in how the model replicas merge results and in whether the replicas stay in sync at every batch or are more loosely coupled.</p> <p><strong>Model parallelism</strong>, where different parts of a single model run on different devices, processing a single batch of data together. This works best with models that have a naturally parallel architecture, such as models that feature multiple branches.</p> <p>This guide focuses on data parallelism, in particular <strong>synchronous data parallelism</strong>, where the different replicas of the model stay in sync after each batch they process. Synchronicity keeps the model convergence behavior identical to what you would see for single-device training.</p> <p>Specifically, this guide teaches you how to use PyTorch's <code>DistributedDataParallel</code> module wrapper to train Keras models, with minimal changes to your code, on multiple GPUs (typically 2 to 16) installed on a single machine (single host, multi-device training). 
This is the most common setup for researchers and small-scale industry workflows.</p> <hr /> <h2 id="setup">Setup</h2> <p>Let's start by defining the function that creates the model that we will train, and the function that creates the dataset we will train on (MNIST in this case).</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"torch"</span> <span class="kn">import</span> <span class="nn">torch</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="k">def</span> <span class="nf">get_model</span><span class="p">():</span> <span class="c1"># Make a simple convnet with batch normalization and dropout.</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">28</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Rescaling</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">/</span> <span class="mf">255.0</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span 
class="n">filters</span><span class="o">=</span><span class="mi">12</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)(</span> <span class="n">x</span> <span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">(</span><span class="n">scale</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">center</span><span class="o">=</span><span class="kc">True</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span> <span class="n">filters</span><span class="o">=</span><span class="mi">24</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">6</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span 
class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">(</span><span class="n">scale</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">center</span><span class="o">=</span><span class="kc">True</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span> <span class="n">filters</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">6</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"large_k"</span><span class="p">,</span> <span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">(</span><span class="n">scale</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">center</span><span class="o">=</span><span class="kc">True</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span 
class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">GlobalAveragePooling2D</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> <span class="k">return</span> <span class="n">model</span> <span class="k">def</span> <span class="nf">get_dataset</span><span class="p">():</span> <span 
class="c1"># Load the data and split it between train and test sets</span> <span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">mnist</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span> <span class="c1"># Cast images to float32; scaling to [0, 1] is done by the Rescaling layer in get_model()</span> <span class="n">x_train</span> <span class="o">=</span> <span class="n">x_train</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">)</span> <span class="n">x_test</span> <span class="o">=</span> <span class="n">x_test</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">)</span> <span class="c1"># Make sure images have shape (28, 28, 1)</span> <span class="n">x_train</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="n">x_test</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"x_train shape:"</span><span class="p">,</span> <span class="n">x_train</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="c1"># Create a 
TensorDataset</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">TensorDataset</span><span class="p">(</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">x_train</span><span class="p">),</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">y_train</span><span class="p">)</span> <span class="p">)</span> <span class="k">return</span> <span class="n">dataset</span> </code></pre></div> <p>Next, let's define a simple PyTorch training loop that targets a GPU (note the calls to <code>.cuda()</code>).</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">train_model</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">dataloader</span><span class="p">,</span> <span class="n">num_epochs</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">):</span> <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_epochs</span><span class="p">):</span> <span class="n">running_loss</span> <span class="o">=</span> <span class="mf">0.0</span> <span class="n">running_loss_count</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">batch_idx</span><span class="p">,</span> <span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">dataloader</span><span class="p">):</span> 
<span class="n">inputs</span> <span class="o">=</span> <span class="n">inputs</span><span class="o">.</span><span class="n">cuda</span><span class="p">(</span><span class="n">non_blocking</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">targets</span> <span class="o">=</span> <span class="n">targets</span><span class="o">.</span><span class="n">cuda</span><span class="p">(</span><span class="n">non_blocking</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="c1"># Forward pass</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">outputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">)</span> <span class="c1"># Backward and optimize</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">zero_grad</span><span class="p">()</span> <span class="n">loss</span><span class="o">.</span><span class="n">backward</span><span class="p">()</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">step</span><span class="p">()</span> <span class="n">running_loss</span> <span class="o">+=</span> <span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="n">running_loss_count</span> <span class="o">+=</span> <span class="mi">1</span> <span class="c1"># Print loss statistics</span> <span class="nb">print</span><span class="p">(</span> <span class="sa">f</span><span class="s2">"Epoch </span><span class="si">{</span><span class="n">epoch</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">1</span><span class="si">}</span><span class="s2">/</span><span class="si">{</span><span 
class="n">num_epochs</span><span class="si">}</span><span class="s2">, "</span> <span class="sa">f</span><span class="s2">"Loss: </span><span class="si">{</span><span class="n">running_loss</span><span class="w"> </span><span class="o">/</span><span class="w"> </span><span class="n">running_loss_count</span><span class="si">}</span><span class="s2">"</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="singlehost-multidevice-synchronous-training">Single-host, multi-device synchronous training</h2> <p>In this setup, you have one machine with several GPUs on it (typically 2 to 16). Each device will run a copy of your model (called a <strong>replica</strong>). For simplicity, in what follows, we'll assume we're dealing with 8 GPUs, without loss of generality.</p> <p><strong>How it works</strong></p> <p>At each step of training:</p> <ul> <li>The current batch of data (called the <strong>global batch</strong>) is split into 8 different sub-batches (called <strong>local batches</strong>). For instance, if the global batch has 512 samples, each of the 8 local batches will have 64 samples.</li> <li>Each of the 8 replicas independently processes a local batch: it runs a forward pass, then a backward pass, producing the gradient of the model's loss on the local batch with respect to the weights.</li> <li>The weight updates originating from local gradients are efficiently merged across the 8 replicas. Because this is done at the end of every step, the replicas always stay in sync.</li> </ul> <p>In practice, <code>DistributedDataParallel</code> keeps the replicas in sync by averaging the gradients of each parameter across processes with an <strong>all-reduce</strong> operation during the backward pass, so every replica applies the same update.</p> <p><strong>How to use it</strong></p> <p>To do single-host, multi-device synchronous training with a Keras model, you can use the <code>torch.nn.parallel.DistributedDataParallel</code> module wrapper. 
Here's how it works:</p> <ul> <li>We use <code>torch.multiprocessing.start_processes</code> to start multiple Python processes, one per device. Each process will run the <code>per_device_launch_fn</code> function.</li> <li>The <code>per_device_launch_fn</code> function does the following: <ul> <li>It uses <code>torch.distributed.init_process_group</code> and <code>torch.cuda.set_device</code> to configure the device to be used for that process.</li> <li>It uses <code>torch.utils.data.distributed.DistributedSampler</code> and <code>torch.utils.data.DataLoader</code> to turn our data into a distributed data loader.</li> <li>It also uses <code>torch.nn.parallel.DistributedDataParallel</code> to turn our model into a distributed PyTorch module.</li> <li>It then calls the <code>train_model</code> function.</li> </ul> </li> <li>The <code>train_model</code> function will then run in each process, with the model using a separate device in each process.</li> </ul> <p>Here's the flow, where each step is split into its own utility function:</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Config</span> <span class="n">num_gpu</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">device_count</span><span class="p">()</span> <span class="n">num_epochs</span> <span class="o">=</span> <span class="mi">2</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">64</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Running on </span><span class="si">{</span><span class="n">num_gpu</span><span class="si">}</span><span class="s2"> GPUs"</span><span class="p">)</span> <span class="k">def</span> <span class="nf">setup_device</span><span class="p">(</span><span class="n">current_gpu_index</span><span class="p">,</span> <span class="n">num_gpus</span><span class="p">):</span> <span class="c1"># Device setup</span> <span 
class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"MASTER_ADDR"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"localhost"</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"MASTER_PORT"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"56492"</span> <span class="n">device</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">"cuda:</span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">current_gpu_index</span><span class="p">))</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">init_process_group</span><span class="p">(</span> <span class="n">backend</span><span class="o">=</span><span class="s2">"nccl"</span><span class="p">,</span> <span class="n">init_method</span><span class="o">=</span><span class="s2">"env://"</span><span class="p">,</span> <span class="n">world_size</span><span class="o">=</span><span class="n">num_gpus</span><span class="p">,</span> <span class="n">rank</span><span class="o">=</span><span class="n">current_gpu_index</span><span class="p">,</span> <span class="p">)</span> <span class="n">torch</span><span class="o">.</span><span class="n">cuda</span><span class="o">.</span><span class="n">set_device</span><span class="p">(</span><span class="n">device</span><span class="p">)</span> <span class="k">def</span> <span class="nf">cleanup</span><span class="p">():</span> <span class="n">torch</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">destroy_process_group</span><span class="p">()</span> <span class="k">def</span> <span 
class="nf">prepare_dataloader</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">current_gpu_index</span><span class="p">,</span> <span class="n">num_gpus</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">):</span> <span class="n">sampler</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">distributed</span><span class="o">.</span><span class="n">DistributedSampler</span><span class="p">(</span> <span class="n">dataset</span><span class="p">,</span> <span class="n">num_replicas</span><span class="o">=</span><span class="n">num_gpus</span><span class="p">,</span> <span class="n">rank</span><span class="o">=</span><span class="n">current_gpu_index</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="p">)</span> <span class="n">dataloader</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">DataLoader</span><span class="p">(</span> <span class="n">dataset</span><span class="p">,</span> <span class="n">sampler</span><span class="o">=</span><span class="n">sampler</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="p">)</span> <span class="k">return</span> <span class="n">dataloader</span> <span class="k">def</span> <span class="nf">per_device_launch_fn</span><span class="p">(</span><span class="n">current_gpu_index</span><span class="p">,</span> <span class="n">num_gpu</span><span 
class="p">):</span> <span class="c1"># Set up the process group</span> <span class="n">setup_device</span><span class="p">(</span><span class="n">current_gpu_index</span><span class="p">,</span> <span class="n">num_gpu</span><span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">get_dataset</span><span class="p">()</span> <span class="n">model</span> <span class="o">=</span> <span class="n">get_model</span><span class="p">()</span> <span class="c1"># Prepare the dataloader</span> <span class="n">dataloader</span> <span class="o">=</span> <span class="n">prepare_dataloader</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">current_gpu_index</span><span class="p">,</span> <span class="n">num_gpu</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">)</span> <span class="c1"># Instantiate the torch optimizer</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">parameters</span><span class="p">(),</span> <span class="n">lr</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">)</span> <span class="c1"># Instantiate the torch loss function</span> <span class="n">loss_fn</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">CrossEntropyLoss</span><span class="p">()</span> <span class="c1"># Move the model to this process's GPU</span> <span class="n">model</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="n">current_gpu_index</span><span class="p">)</span> <span class="n">ddp_model</span> <span class="o">=</span> <span 
class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">parallel</span><span class="o">.</span><span class="n">DistributedDataParallel</span><span class="p">(</span> <span class="n">model</span><span class="p">,</span> <span class="n">device_ids</span><span class="o">=</span><span class="p">[</span><span class="n">current_gpu_index</span><span class="p">],</span> <span class="n">output_device</span><span class="o">=</span><span class="n">current_gpu_index</span> <span class="p">)</span> <span class="n">train_model</span><span class="p">(</span><span class="n">ddp_model</span><span class="p">,</span> <span class="n">dataloader</span><span class="p">,</span> <span class="n">num_epochs</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">)</span> <span class="n">cleanup</span><span class="p">()</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Running on 0 GPUs /opt/conda/envs/keras-torch/lib/python3.10/site-packages/torch/cuda/__init__.py:611: UserWarning: Can't initialize NVML warnings.warn("Can't initialize NVML") </code></pre></div> </div> <p>Time to start multiple processes:</p> <div class="codehilite"><pre><span></span><code><span class="k">if</span> <span class="vm">__name__</span> <span class="o">==</span> <span class="s2">"__main__"</span><span class="p">:</span> <span class="c1"># We use the "fork" method rather than "spawn" to support notebooks</span> <span class="n">torch</span><span class="o">.</span><span class="n">multiprocessing</span><span class="o">.</span><span class="n">start_processes</span><span class="p">(</span> <span class="n">per_device_launch_fn</span><span class="p">,</span> <span class="n">args</span><span class="o">=</span><span class="p">(</span><span class="n">num_gpu</span><span class="p">,),</span> <span class="n">nprocs</span><span 
class="o">=</span><span class="n">num_gpu</span><span class="p">,</span> <span class="n">join</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">start_method</span><span class="o">=</span><span class="s2">"fork"</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <p>That's it!</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#multigpu-distributed-training-with-pytorch'>Multi-GPU distributed training with PyTorch</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#singlehost-multidevice-synchronous-training'>Single-host, multi-device synchronous training</a> </div> </div> </div> </div> </div> </body> </html>