CINXE.COM
Scale up Flax Modules on multiple devices
<!DOCTYPE html> <html lang="en" data-content_root="" > <head> <meta charset="utf-8" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" /> <title>Scale up Flax Modules on multiple devices</title> <script data-cfasync="false"> document.documentElement.dataset.mode = localStorage.getItem("mode") || ""; document.documentElement.dataset.theme = localStorage.getItem("theme") || ""; </script> <!-- Loaded before other Sphinx assets --> <link href="../../_static/styles/theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link href="../../_static/styles/bootstrap.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link href="../../_static/styles/pydata-sphinx-theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link href="../../_static/vendor/fontawesome/6.5.2/css/all.min.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="../../_static/vendor/fontawesome/6.5.2/webfonts/fa-solid-900.woff2" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="../../_static/vendor/fontawesome/6.5.2/webfonts/fa-brands-400.woff2" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="../../_static/vendor/fontawesome/6.5.2/webfonts/fa-regular-400.woff2" /> <link rel="stylesheet" type="text/css" href="../../_static/pygments.css" /> <link rel="stylesheet" type="text/css" href="../../_static/styles/sphinx-book-theme.css" /> <link rel="stylesheet" type="text/css" href="../../_static/mystnb.4510f1fc1dee50b3e5859aac5469c37c29e427902b24a333a5f9fcb2f0b3ac41.css" /> <link rel="stylesheet" type="text/css" href="../../_static/sphinx-design.5ea377869091fd0449014c60fc090103.min.css" /> <link rel="stylesheet" type="text/css" href="../../_static/css/flax_theme.css" /> <!-- Pre-loaded scripts that we'll load fully later --> <link rel="preload" as="script" href="../../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b" /> <link rel="preload" as="script" href="../../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" /> <script src="../../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script> <script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script> <script src="../../_static/jquery.js"></script> <script src="../../_static/underscore.js"></script> <script src="../../_static/_sphinx_javascript_frameworks_compat.js"></script> <script src="../../_static/doctools.js"></script> <script src="../../_static/sphinx_highlight.js"></script> <script src="../../_static/scripts/sphinx-book-theme.js"></script> <script src="../../_static/design-tabs.js"></script> <script>DOCUMENTATION_OPTIONS.pagename = 'guides/parallel_training/flax_on_pjit';</script> <link rel="shortcut icon" href="../../_static/flax.png"/> <link rel="index" title="Index" href="../../genindex.html" /> <link rel="search" title="Search" href="../../search.html" /> <link rel="next" title="Model inspection" href="../model_inspection/index.html" /> <link rel="prev" title="Ensembling on multiple devices" href="ensembling.html" /> <meta name="viewport" content="width=device-width, initial-scale=1"/> <meta name="docsearch:language" content="en"/> <script async type="text/javascript" src="/_/static/javascript/readthedocs-addons.js"></script><meta name="readthedocs-project-slug" content="flax-linen" /><meta name="readthedocs-version-slug" content="latest" /><meta name="readthedocs-resolver-filename" content="/guides/parallel_training/flax_on_pjit.html" /><meta name="readthedocs-http-status" content="200" /></head> <body data-bs-spy="scroll" data-bs-target=".bd-toc-nav" data-offset="180" data-bs-root-margin="0px 0px -60%" data-default-mode=""> <div id="pst-skip-link" class="skip-link d-print-none"><a href="#main-content">Skip to main content</a></div> <div id="pst-scroll-pixel-helper"></div> <button type="button" class="btn rounded-pill" id="pst-back-to-top"> <i class="fa-solid fa-arrow-up"></i>Back to top</button> <input type="checkbox" class="sidebar-toggle" id="pst-primary-sidebar-checkbox"/> <label class="overlay overlay-primary" for="pst-primary-sidebar-checkbox"></label> <input type="checkbox" class="sidebar-toggle" id="pst-secondary-sidebar-checkbox"/> <label class="overlay overlay-secondary" for="pst-secondary-sidebar-checkbox"></label> <div class="search-button__wrapper"> <div class="search-button__overlay"></div> <div class="search-button__search-container"> <form class="bd-search d-flex align-items-center" action="../../search.html" method="get"> <i class="fa-solid fa-magnifying-glass"></i> <input type="search" class="form-control" name="q" id="search-input" placeholder="Search..." aria-label="Search..." autocomplete="off" autocorrect="off" autocapitalize="off" spellcheck="false"/> <span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span> </form></div> </div> <div class="pst-async-banner-revealer d-none"> <aside id="bd-header-version-warning" class="d-none d-print-none" aria-label="Version warning"></aside> </div> <aside class="bd-header-announcement" aria-label="Announcement"> <div class="bd-header-announcement__content"> <a href="https://flax.readthedocs.io/en/latest/index.html" style="text-decoration: none; color: white;" > This site covers the old Flax Linen API. <span style="color: lightgray;">[Explore the new <b>Flax NNX</b> API ✨]</span> </a> </div> </aside> <header class="bd-header navbar navbar-expand-lg bd-navbar d-print-none"> </header> <div class="bd-container"> <div class="bd-container__inner bd-page-width"> <div class="bd-sidebar-primary bd-sidebar"> <div class="sidebar-header-items sidebar-primary__section"> </div> <div class="sidebar-primary-items__start sidebar-primary__section"> <div class="sidebar-primary-item"> <a class="navbar-brand logo" href="../../index.html"> <img src="../../_static/flax.png" class="logo__image only-light" alt=" - Home"/> <script>document.write(`<img src="../../_static/flax.png" class="logo__image only-dark" alt=" - Home"/>`);</script> </a></div> <div class="sidebar-primary-item"> <script> document.write(` <button class="btn search-button-field search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="fa-solid fa-magnifying-glass"></i> <span class="search-button__default-text">Search</span> <span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span> </button> `); </script></div> <div class="sidebar-primary-item"><nav class="bd-links bd-docs-nav" aria-label="Main"> <div class="bd-toc-item navbar-nav active"> <ul class="current nav bd-sidenav"> <li class="toctree-l1"><a class="reference internal" href="../../quick_start.html">Quick start</a></li> <li class="toctree-l1"><a class="reference internal" href="../flax_fundamentals/flax_basics.html">Flax Basics</a></li> <li class="toctree-l1 current active has-children"><a class="reference internal" href="../index.html">Guides</a><details open="open"><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul class="current"> <li class="toctree-l2 has-children"><a class="reference internal" href="../flax_fundamentals/index.html">Flax fundamentals</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference external" href="https://jax.readthedocs.io/en/latest/jax-101/index.html">JAX 101</a></li> <li class="toctree-l3"><a class="reference internal" href="../flax_fundamentals/flax_basics.html">Flax Basics</a></li> <li class="toctree-l3"><a class="reference internal" href="../flax_fundamentals/state_params.html">Managing Parameters and State</a></li> <li class="toctree-l3"><a class="reference internal" href="../flax_fundamentals/setup_or_nncompact.html"><code class="docutils literal notranslate"><span class="pre">setup</span></code> vs <code class="docutils literal notranslate"><span class="pre">compact</span></code></a></li> <li class="toctree-l3"><a class="reference internal" href="../flax_fundamentals/arguments.html">Dealing with Flax Module arguments</a></li> <li class="toctree-l3"><a class="reference internal" href="../flax_fundamentals/rng_guide.html">Randomness and PRNGs in Flax</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../data_preprocessing/index.html">Data preprocessing</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../data_preprocessing/full_eval.html">Processing the entire Dataset</a></li> <li class="toctree-l3"><a class="reference internal" href="../data_preprocessing/loading_datasets.html">Loading datasets</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../training_techniques/index.html">Training techniques</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../training_techniques/batch_norm.html">Batch normalization</a></li> <li class="toctree-l3"><a class="reference internal" href="../training_techniques/dropout.html">Dropout</a></li> <li class="toctree-l3"><a class="reference internal" href="../training_techniques/lr_schedule.html">Learning rate scheduling</a></li> <li class="toctree-l3"><a class="reference internal" href="../training_techniques/transfer_learning.html">Transfer learning</a></li> <li class="toctree-l3"><a class="reference internal" href="../training_techniques/use_checkpointing.html">Save and load checkpoints</a></li> </ul> </details></li> <li class="toctree-l2 current active has-children"><a class="reference internal" href="index.html">Parallel training</a><details open="open"><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul class="current"> <li class="toctree-l3"><a class="reference internal" href="ensembling.html">Ensembling on multiple devices</a></li> <li class="toctree-l3 current active"><a class="current reference internal" href="#">Scale up Flax Modules on multiple devices</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../model_inspection/index.html">Model inspection</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../model_inspection/model_surgery.html">Model surgery</a></li> <li class="toctree-l3"><a class="reference internal" href="../model_inspection/extracting_intermediates.html">Extracting intermediate values</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../converting_and_upgrading/index.html">Converting and upgrading</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../converting_and_upgrading/haiku_migration_guide.html">Migrating from Haiku to Flax</a></li> <li class="toctree-l3"><a class="reference internal" href="../converting_and_upgrading/convert_pytorch_to_flax.html">Convert PyTorch models to Flax</a></li> <li class="toctree-l3"><a class="reference internal" href="../converting_and_upgrading/orbax_upgrade_guide.html">Migrate checkpointing to Orbax</a></li> <li class="toctree-l3"><a class="reference internal" href="../converting_and_upgrading/optax_update_guide.html">Upgrading my codebase to Optax</a></li> <li class="toctree-l3"><a class="reference internal" href="../converting_and_upgrading/linen_upgrade_guide.html">Upgrading my codebase to Linen</a></li> <li class="toctree-l3"><a class="reference internal" href="../converting_and_upgrading/rnncell_upgrade_guide.html">RNNCellBase Upgrade Guide</a></li> <li class="toctree-l3"><a class="reference internal" href="../converting_and_upgrading/regular_dict_upgrade_guide.html">Migrate to regular dicts</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../quantization/index.html">Quantization</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../quantization/fp8_basics.html">User Guide on Using FP8</a></li> </ul> </details></li> <li class="toctree-l2"><a class="reference internal" href="../flax_sharp_bits.html">The Sharp Bits</a></li> </ul> </details></li> <li class="toctree-l1 has-children"><a class="reference internal" href="../../examples/index.html">Examples</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l2"><a class="reference internal" href="../../examples/core_examples.html">Core examples</a></li> <li class="toctree-l2"><a class="reference internal" href="../../examples/google_research_examples.html">Google Research examples</a></li> <li class="toctree-l2"><a class="reference internal" href="../../examples/repositories_that_use_flax.html">Repositories that use Flax</a></li> <li class="toctree-l2"><a class="reference internal" href="../../examples/community_examples.html">Community examples</a></li> </ul> </details></li> <li class="toctree-l1"><a class="reference internal" href="../../glossary.html">Glossary</a></li> <li class="toctree-l1"><a class="reference internal" href="../../faq.html">Frequently Asked Questions (FAQ)</a></li> <li class="toctree-l1 has-children"><a class="reference internal" href="../../developer_notes/index.html">Developer notes</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l2"><a class="reference internal" href="../../developer_notes/module_lifecycle.html">The Flax Module lifecycle</a></li> <li class="toctree-l2"><a class="reference internal" href="../../developer_notes/lift.html">Lifted transformations</a></li> <li class="toctree-l2"><a class="reference external" href="https://github.com/google/flax/tree/main/docs/flip">FLIPs</a></li> </ul> </details></li> <li class="toctree-l1"><a class="reference internal" href="../../philosophy.html">The Flax philosophy</a></li> <li class="toctree-l1"><a class="reference internal" href="../../contributing.html">How to contribute</a></li> <li class="toctree-l1 has-children"><a class="reference internal" href="../../api_reference/index.html">API Reference</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.config.html">flax.config package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.core.frozen_dict.html">flax.core.frozen_dict package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.cursor.html">flax.cursor package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.errors.html">flax.errors package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.jax_utils.html">flax.jax_utils package</a></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../../api_reference/flax.linen/index.html">flax.linen</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/module.html">Module</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/init_apply.html">Init/Apply</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/layers.html">Layers</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/activation_functions.html">Activation functions</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/initializers.html">Initializers</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/transformations.html">Transformations</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/inspection.html">Inspection</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/variable.html">Variable dictionary</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/spmd.html">SPMD</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/decorators.html">Decorators</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/profiling.html">Profiling</a></li> </ul> </details></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.serialization.html">flax.serialization package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.struct.html">flax.struct package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.traceback_util.html">flax.traceback_util package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.training.html">flax.training package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.traverse_util.html">flax.traverse_util package</a></li> </ul> </details></li> <li class="toctree-l1"><a class="reference external" href="https://flax.readthedocs.io/en/latest/index.html">Flax NNX</a></li> </ul> </div> </nav></div> </div> <div class="sidebar-primary-items__end sidebar-primary__section"> </div> <div id="rtd-footer-container"></div> </div> <main id="main-content" class="bd-main" role="main"> <div class="sbt-scroll-pixel-helper"></div> <div class="bd-content"> <div class="bd-article-container"> <div class="bd-header-article d-print-none"> <div class="header-article-items header-article__inner"> <div class="header-article-items__start"> <div class="header-article-item"><button class="sidebar-toggle primary-toggle btn btn-sm" title="Toggle primary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip"> <span class="fa-solid fa-bars"></span> </button></div> </div> <div class="header-article-items__end"> <div class="header-article-item"> <div class="article-header-buttons"> <div class="dropdown dropdown-launch-buttons"> <button class="btn dropdown-toggle" type="button" data-bs-toggle="dropdown" aria-expanded="false" aria-label="Launch interactive content"> <i class="fas fa-rocket"></i> </button> <ul class="dropdown-menu"> <li><a href="https://colab.research.google.com/github/google/flax/blob/master/docs/guides/parallel_training/flax_on_pjit.ipynb" target="_blank" class="btn btn-sm dropdown-item" title="Launch on Colab" data-bs-placement="left" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <img alt="Colab logo" src="../../_static/images/logo_colab.png"> </span> <span class="btn__text-container">Colab</span> </a> </li> </ul> </div> <a href="https://github.com/google/flax" target="_blank" class="btn btn-sm btn-source-repository-button" title="Source repository" data-bs-placement="bottom" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fab fa-github"></i> </span> </a> <div class="dropdown dropdown-download-buttons"> <button class="btn dropdown-toggle" type="button" data-bs-toggle="dropdown" aria-expanded="false" aria-label="Download this page"> <i class="fas fa-download"></i> </button> <ul class="dropdown-menu"> <li><a href="../../_sources/guides/parallel_training/flax_on_pjit.ipynb" target="_blank" class="btn btn-sm btn-download-source-button dropdown-item" title="Download source file" data-bs-placement="left" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-file"></i> </span> <span class="btn__text-container">.ipynb</span> </a> </li> <li> <button onclick="window.print()" class="btn btn-sm btn-download-pdf-button dropdown-item" title="Print to PDF" data-bs-placement="left" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-file-pdf"></i> </span> <span class="btn__text-container">.pdf</span> </button> </li> </ul> </div> <button onclick="toggleFullScreen()" class="btn btn-sm btn-fullscreen-button" title="Fullscreen mode" data-bs-placement="bottom" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-expand"></i> </span> </button> <script> document.write(` <button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button" title="light/dark" aria-label="light/dark" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light"></i> <i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark"></i> <i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto"></i> </button> `); </script> <script> document.write(` <button class="btn btn-sm pst-navbar-icon search-button search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="fa-solid fa-magnifying-glass fa-lg"></i> </button> `); </script> <button class="sidebar-toggle secondary-toggle btn btn-sm" title="Toggle secondary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip"> <span class="fa-solid fa-list"></span> </button> </div></div> </div> </div> </div> <div id="jb-print-docs-body" class="onlyprint"> <h1>Scale up Flax Modules on multiple devices</h1> <!-- Table of contents --> <div id="print-main-content"> <div id="jb-print-toc"> <div> <h2> Contents </h2> </div> <nav aria-label="Page"> <ul class="visible nav section-nav flex-column"> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax-and-jax-jit-scaled-up">Flax and <code class="docutils literal notranslate"><span class="pre">jax.jit</span></code> scaled up</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#setup">Setup</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#define-a-layer">Define a layer</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#define-a-model-with-flax-linen-scan-lifted-transformation">Define a model with <code class="docutils literal notranslate"><span class="pre">flax.linen.scan</span></code> lifted transformation</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#specify-sharding">Specify sharding</a><ul class="nav section-nav flex-column"> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#the-input-s-sharding">The input’s sharding</a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#the-output-s-sharding">The output’s sharding</a></li> </ul> </li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#compile-the-code">Compile the code</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#inspect-the-module-output">Inspect the Module output</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#compile-the-train-step-and-inference">Compile the train step and inference</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#profiling">Profiling</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#logical-axis-annotation">Logical axis annotation</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#when-to-use-device-axis-logical-axis">When to use device axis / logical axis</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#save-the-data">Save the data</a></li> </ul> </nav> </div> </div> </div> <div id="searchbox"></div> <article class="bd-article"> <div class="tex2jax_ignore mathjax_ignore section" id="scale-up-flax-modules-on-multiple-devices"> <h1>Scale up Flax Modules on multiple devices<a class="headerlink" href="#scale-up-flax-modules-on-multiple-devices" title="Permalink to this heading">#</a></h1> <p>This guide shows how to scale up <a class="reference external" href="https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html">Flax Modules</a> on multiple devices and hosts using <a class="reference external" href="https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html"><code class="docutils literal notranslate"><span class="pre">jax.jit</span></code></a> (formerly <a class="reference external" href="https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html#module-jax.experimental.pjit"><code class="docutils literal notranslate"><span class="pre">experimental.pjit</span></code></a>) and <a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.linen/index.html"><code class="docutils literal notranslate"><span class="pre">flax.linen</span></code></a>.</p> <div class="section" id="flax-and-jax-jit-scaled-up"> <h2>Flax and <code class="docutils literal notranslate"><span class="pre">jax.jit</span></code> scaled up<a class="headerlink" href="#flax-and-jax-jit-scaled-up" title="Permalink to this heading">#</a></h2> <p><a class="reference external" href="https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html"><code class="docutils literal notranslate"><span class="pre">jax.jit</span></code></a> follows the <a class="reference external" href="https://jax.readthedocs.io/en/latest/glossary.html#term-SPMD">Single Program Multi Data (SPMD)</a> paradigm and automatically compiles your code to run it on multiple devices. You need to only specify how you want the input and output of your code to be partitioned, and the compiler will figure out how to: 1) partition everything inside; and 2) compile inter-device communications.</p> <p>Flax provides several functionalities that can help you use auto-SPMD on <a class="reference external" href="https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html">Flax Modules</a>, including:</p> <ol class="arabic simple"> <li><p>An interface to specify partitions of your data when defining <a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html"><code class="docutils literal notranslate"><span class="pre">flax.linen.Module</span></code></a>.</p></li> <li><p>Utility functions to generate the sharding information that <code class="docutils literal notranslate"><span class="pre">jax.jit</span></code> requires to run.</p></li> <li><p>An interface to customize your axis names called “logical axis annotations” to decouple both your Module code and partition plan to experiment with different partition layouts more easily.</p></li> </ol> <p>You can learn more about <code class="docutils literal notranslate"><span class="pre">jax.jit</span></code> APIs for scaling up in <a class="reference external" href="https://jax.readthedocs.io/en/latest/multi_process.html">JAX in multi-process environments</a> and <a class="reference external" href="https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html">Distributed arrays and automatic parallelization</a> on JAX’s documentation site.</p> </div> <div class="section" id="setup"> <h2>Setup<a class="headerlink" href="#setup" title="Permalink to this heading">#</a></h2> <p>Import some necessary dependencies.</p> <p><strong>Note:</strong> This guide uses the <code class="docutils literal notranslate"><span class="pre">--xla_force_host_platform_device_count=8</span></code> flag to emulate multiple devices in a CPU environment in a Google Colab/Jupyter Notebook. You don’t need this if you are already using a multi-device TPU environment.</p> <div class="cell tag_skip-execution docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Once Flax v0.6.10 is released, there is no need to do this.</span> <span class="c1"># ! pip3 install -qq "git+https://github.com/google/flax.git@main#egg=flax"</span> </pre></div> </div> </div> </div> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">os</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"XLA_FLAGS"</span><span class="p">]</span> <span class="o">=</span> <span class="s1">'--xla_force_host_platform_device_count=8'</span> </pre></div> </div> </div> </div> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">functools</span> <span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Optional</span><span class="p">,</span> <span class="n">Callable</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">jax</span> <span class="kn">from</span> <span class="nn">jax</span> <span class="kn">import</span> <span class="n">lax</span><span class="p">,</span> <span class="n">random</span><span class="p">,</span> <span class="n">numpy</span> <span class="k">as</span> <span class="n">jnp</span> <span class="kn">import</span> <span class="nn">flax</span> <span class="kn">from</span> <span class="nn">flax</span> <span class="kn">import</span> <span class="n">struct</span><span class="p">,</span> <span class="n">traverse_util</span><span class="p">,</span> <span class="n">linen</span> <span class="k">as</span> <span class="n">nn</span> <span class="kn">from</span> <span class="nn">flax.core</span> <span class="kn">import</span> <span class="n">freeze</span><span class="p">,</span> <span class="n">unfreeze</span> <span class="kn">from</span> <span class="nn">flax.training</span> <span class="kn">import</span> <span class="n">train_state</span><span class="p">,</span> <span class="n">checkpoints</span> <span class="kn">import</span> <span class="nn">optax</span> <span class="c1"># Optax for common losses and optimizers.</span> </pre></div> </div> </div> <div class="cell_output docutils container"> <div class="output stderr highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>2024-11-20 21:52:15.791706: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered WARNING: All log messages before absl::InitializeLog() is called are written to STDERR E0000 00:00:1732139535.811354 945 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered E0000 00:00:1732139535.817305 945 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered </pre></div> </div> </div> </div> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'We have 8 fake JAX devices now: </span><span class="si">{</span><span class="n">jax</span><span class="o">.</span><span class="n">devices</span><span class="p">()</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span> </pre></div> </div> </div> <div class="cell_output docutils container"> <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>We have 8 fake JAX devices now: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)] </pre></div> </div> </div> </div> <p>The code below shows how to import and set up the JAX-level device API, following JAX’s <a class="reference external" href="https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html">Distributed arrays and automatic parallelization</a> guide:</p> <ol class="arabic simple"> <li><p>Start a 2x4 device <code class="docutils literal notranslate"><span class="pre">mesh</span></code> (8 devices) using JAX’s <code class="docutils literal notranslate"><span class="pre">mesh_utils.create_device_mesh</span></code>. This layout is the same as the one of a <a class="reference external" href="https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#single_tpu_board">TPU v3-8</a>.</p></li> <li><p>Annotate each axis with a name using the <code class="docutils literal notranslate"><span class="pre">axis_names</span></code> parameter in <code class="docutils literal notranslate"><span class="pre">jax.sharding.Mesh</span></code>. A typical way to annotate axis names is <code class="docutils literal notranslate"><span class="pre">axis_name=('data',</span> <span class="pre">'model')</span></code>, where:</p></li> </ol> <ul class="simple"> <li><p><code class="docutils literal notranslate"><span class="pre">'data'</span></code>: the mesh dimension used for data-parallel sharding of the batch dimension of inputs and activations.</p></li> <li><p><code class="docutils literal notranslate"><span class="pre">'model'</span></code>: the mesh dimension used for sharding parameters of the model across devices.</p></li> </ul> <ol class="arabic simple" start="3"> <li><p>Make a simple utility function <code class="docutils literal notranslate"><span class="pre">mesh_sharding</span></code> for generating a sharding object from the mesh and any layout.</p></li> </ol> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">jax.sharding</span> <span class="kn">import</span> <span class="n">Mesh</span><span class="p">,</span> <span class="n">PartitionSpec</span><span class="p">,</span> <span class="n">NamedSharding</span> <span class="kn">from</span> <span class="nn">jax.lax</span> <span class="kn">import</span> <span class="n">with_sharding_constraint</span> <span class="kn">from</span> <span class="nn">jax.experimental</span> <span class="kn">import</span> <span class="n">mesh_utils</span> </pre></div> </div> </div> </div> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Create a mesh and annotate each axis with a name.</span> <span class="n">device_mesh</span> <span class="o">=</span> <span class="n">mesh_utils</span><span class="o">.</span><span class="n">create_device_mesh</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="n">device_mesh</span><span class="p">)</span> <span class="n">mesh</span> <span class="o">=</span> <span class="n">Mesh</span><span class="p">(</span><span class="n">devices</span><span class="o">=</span><span class="n">device_mesh</span><span class="p">,</span> <span class="n">axis_names</span><span class="o">=</span><span class="p">(</span><span class="s1">'data'</span><span class="p">,</span> <span class="s1">'model'</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="n">mesh</span><span class="p">)</span> <span class="k">def</span> <span class="nf">mesh_sharding</span><span class="p">(</span><span class="n">pspec</span><span class="p">:</span> <span class="n">PartitionSpec</span><span class="p">)</span> <span class="o">-></span> <span class="n">NamedSharding</span><span class="p">:</span> <span class="k">return</span> <span class="n">NamedSharding</span><span class="p">(</span><span class="n">mesh</span><span class="p">,</span> <span class="n">pspec</span><span class="p">)</span> </pre></div> </div> </div> <div class="cell_output docutils container"> <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>[[CpuDevice(id=0) CpuDevice(id=1) CpuDevice(id=2) CpuDevice(id=3)] [CpuDevice(id=4) CpuDevice(id=5) CpuDevice(id=6) CpuDevice(id=7)]] Mesh('data': 2, 'model': 4) </pre></div> </div> </div> </div> </div> <div class="section" id="define-a-layer"> <h2>Define a layer<a class="headerlink" href="#define-a-layer" title="Permalink to this heading">#</a></h2> <p>Before defining a simple model, create an example layer called <code class="docutils literal notranslate"><span class="pre">DotReluDot</span></code> (by subclassing <code class="docutils literal notranslate"><span class="pre">flax.linen.Module</span></code>). The layer creates two parameters <code class="docutils literal notranslate"><span class="pre">W1</span></code> and <code class="docutils literal notranslate"><span class="pre">W2</span></code> for dot product multiplication, and uses the <code class="docutils literal notranslate"><span class="pre">jax.nn.relu</span></code> (ReLU) activation function in-between.</p> <p>To shard the parameters efficiently, apply the following APIs to annotate the parameters and intermediate variables:</p> <ol class="arabic simple"> <li><p>Use <a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning"><code class="docutils literal notranslate"><span class="pre">flax.linen.with_partitioning</span></code></a> to decorate the initializer function when creating sub-layers or raw parameters.</p></li> <li><p>Apply <a class="reference external" href="https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html"><code class="docutils literal notranslate"><span class="pre">jax.lax.with_sharding_constraint</span></code></a> (formerly, <code class="docutils literal notranslate"><span class="pre">pjit.with_sharding_constraint</span></code>) to annotate intermediate variables like <code class="docutils literal notranslate"><span class="pre">y</span></code> and <code class="docutils literal notranslate"><span class="pre">z</span></code> to force a particular sharding pattern when the ideal constraint is known.</p></li> </ol> <ul class="simple"> <li><p>This step is optional, but can sometimes help auto-SPMD to partition efficiently. In the example below, the call is not required, because XLA will figure out the same sharding layout for <code class="docutils literal notranslate"><span class="pre">y</span></code> and <code class="docutils literal notranslate"><span class="pre">z</span></code> regardless.</p></li> </ul> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">DotReluDot</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="n">depth</span><span class="p">:</span> <span class="nb">int</span> <span class="n">dense_init</span><span class="p">:</span> <span class="n">Callable</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">xavier_normal</span><span class="p">()</span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="n">y</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">depth</span><span class="p">,</span> <span class="n">kernel_init</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">with_partitioning</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dense_init</span><span class="p">,</span> <span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="s1">'model'</span><span class="p">)),</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="c1"># or overwrite with `bias_init`</span> <span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">y</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">y</span><span class="p">)</span> <span class="c1"># Force a local sharding annotation.</span> <span class="n">y</span> <span class="o">=</span> <span class="n">with_sharding_constraint</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">mesh_sharding</span><span class="p">(</span><span class="n">PartitionSpec</span><span class="p">(</span><span class="s1">'data'</span><span class="p">,</span> <span class="s1">'model'</span><span class="p">)))</span> <span class="n">W2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">param</span><span class="p">(</span> <span class="s1">'W2'</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">with_partitioning</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dense_init</span><span class="p">,</span> <span class="p">(</span><span class="s1">'model'</span><span class="p">,</span> <span class="kc">None</span><span class="p">)),</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">depth</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]))</span> <span class="n">z</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">W2</span><span class="p">)</span> <span class="c1"># Force a local sharding annotation.</span> <span class="n">z</span> <span class="o">=</span> <span class="n">with_sharding_constraint</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">mesh_sharding</span><span class="p">(</span><span class="n">PartitionSpec</span><span class="p">(</span><span class="s1">'data'</span><span class="p">,</span> <span class="kc">None</span><span class="p">)))</span> <span class="c1"># Return a tuple to conform with the API `flax.linen.scan` as shown in the cell below.</span> <span class="k">return</span> <span class="n">z</span><span class="p">,</span> <span class="kc">None</span> </pre></div> </div> </div> </div> <p>Note that device axis names like <code class="docutils literal notranslate"><span class="pre">'data'</span></code>, <code class="docutils literal notranslate"><span class="pre">'model'</span></code> or <code class="docutils literal notranslate"><span class="pre">None</span></code> are passed into both <a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning"><code class="docutils literal notranslate"><span class="pre">flax.linen.with_partitioning</span></code></a> and <a class="reference external" href="https://github.com/google/jax/blob/main/jax/_src/pjit.py#L1516"><code class="docutils literal notranslate"><span class="pre">jax.lax.with_sharding_constraint</span></code></a> API calls. This refers to how each dimension of this data should be sharded — either across one of the device mesh dimensions, or not sharded at all.</p> <p>For example:</p> <ul class="simple"> <li><p>When you define <code class="docutils literal notranslate"><span class="pre">W1</span></code> with shape <code class="docutils literal notranslate"><span class="pre">(x.shape[-1],</span> <span class="pre">self.depth)</span></code> and annotate as <code class="docutils literal notranslate"><span class="pre">(None,</span> <span class="pre">'model')</span></code>:</p> <ul> <li><p>The first dimension (of length <code class="docutils literal notranslate"><span class="pre">x.shape[-1]</span></code>) will be replicated across all devices.</p></li> <li><p>The second dimension (of length <code class="docutils literal notranslate"><span class="pre">self.depth</span></code>) will be sharded over the <code class="docutils literal notranslate"><span class="pre">'model'</span></code> axis of the device mesh. This means <code class="docutils literal notranslate"><span class="pre">W1</span></code> will be sharded 4-way on devices <code class="docutils literal notranslate"><span class="pre">(0,</span> <span class="pre">4)</span></code>, <code class="docutils literal notranslate"><span class="pre">(1,</span> <span class="pre">5)</span></code>, <code class="docutils literal notranslate"><span class="pre">(2,</span> <span class="pre">6)</span></code> and <code class="docutils literal notranslate"><span class="pre">(3,</span> <span class="pre">7)</span></code>, on this dimension.</p></li> </ul> </li> <li><p>When you annotate the output <code class="docutils literal notranslate"><span class="pre">z</span></code> as <code class="docutils literal notranslate"><span class="pre">('data',</span> <span class="pre">None)</span></code>:</p> <ul> <li><p>The first dimension — the batch dimension — will be sharded over the <code class="docutils literal notranslate"><span class="pre">'data'</span></code> axis. This means half of the batch will be processed on devices <code class="docutils literal notranslate"><span class="pre">0-3</span></code> (first four devices), and another half on devices <code class="docutils literal notranslate"><span class="pre">4-7</span></code> (the remaining four devices).</p></li> <li><p>The second dimension — the data depth dimension — will be replicated across all devices.</p></li> </ul> </li> </ul> </div> <div class="section" id="define-a-model-with-flax-linen-scan-lifted-transformation"> <h2>Define a model with <code class="docutils literal notranslate"><span class="pre">flax.linen.scan</span></code> lifted transformation<a class="headerlink" href="#define-a-model-with-flax-linen-scan-lifted-transformation" title="Permalink to this heading">#</a></h2> <p>Having created <code class="docutils literal notranslate"><span class="pre">DotReluDot</span></code>, you can now define the <code class="docutils literal notranslate"><span class="pre">MLP</span></code> model (by subclassing <a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module"><code class="docutils literal notranslate"><span class="pre">flax.linen.Module</span></code></a>) as multiple layers of <code class="docutils literal notranslate"><span class="pre">DotReluDot</span></code>.</p> <p>To replicate identical layers, you can either use <a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.linen/transformations.html#flax.linen.scan"><code class="docutils literal notranslate"><span class="pre">flax.linen.scan</span></code></a>, or a for-loop:</p> <ul class="simple"> <li><p><code class="docutils literal notranslate"><span class="pre">flax.linen.scan</span></code> can provide faster compilation times.</p></li> <li><p>The for-loop can be faster on runtime.</p></li> </ul> <p>The code below shows how to apply both methods, and default with the for-loop, so that all the parameters are two-dimensional and you can visualize their sharding.</p> <p>The <code class="docutils literal notranslate"><span class="pre">flax.linen.scan</span></code> code is just to show that this API works with <a class="reference external" href="https://flax.readthedocs.io/en/latest/developer_notes/lift.html#supported-transformations">Flax lifted transforms</a>.</p> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">MLP</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="n">num_layers</span><span class="p">:</span> <span class="nb">int</span> <span class="n">depth</span><span class="p">:</span> <span class="nb">int</span> <span class="n">use_scan</span><span class="p">:</span> <span class="nb">bool</span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_scan</span><span class="p">:</span> <span class="n">x</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">scan</span><span class="p">(</span><span class="n">DotReluDot</span><span class="p">,</span> <span class="n">length</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">,</span> <span class="n">variable_axes</span><span class="o">=</span><span class="p">{</span><span class="s2">"params"</span><span class="p">:</span> <span class="mi">0</span><span class="p">},</span> <span class="n">split_rngs</span><span class="o">=</span><span class="p">{</span><span class="s2">"params"</span><span class="p">:</span> <span class="kc">True</span><span class="p">},</span> <span class="n">metadata_params</span><span class="o">=</span><span class="p">{</span><span class="n">nn</span><span class="o">.</span><span class="n">PARTITION_NAME</span><span class="p">:</span> <span class="kc">None</span><span class="p">}</span> <span class="p">)(</span><span class="bp">self</span><span class="o">.</span><span class="n">depth</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">):</span> <span class="n">x</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">DotReluDot</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">depth</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> </pre></div> </div> </div> </div> <p>Now, create a <code class="docutils literal notranslate"><span class="pre">model</span></code> instance, and a sample input <code class="docutils literal notranslate"><span class="pre">x</span></code>.</p> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># MLP hyperparameters.</span> <span class="n">BATCH</span><span class="p">,</span> <span class="n">LAYERS</span><span class="p">,</span> <span class="n">DEPTH</span><span class="p">,</span> <span class="n">USE_SCAN</span> <span class="o">=</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">1024</span><span class="p">,</span> <span class="kc">False</span> <span class="c1"># Create fake inputs.</span> <span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">BATCH</span><span class="p">,</span> <span class="n">DEPTH</span><span class="p">))</span> <span class="c1"># Initialize a PRNG key.</span> <span class="n">k</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="c1"># Create an Optax optimizer.</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">optax</span><span class="o">.</span><span class="n">adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.001</span><span class="p">)</span> <span class="c1"># Instantiate the model.</span> <span class="n">model</span> <span class="o">=</span> <span class="n">MLP</span><span class="p">(</span><span class="n">LAYERS</span><span class="p">,</span> <span class="n">DEPTH</span><span class="p">,</span> <span class="n">USE_SCAN</span><span class="p">)</span> </pre></div> </div> </div> </div> </div> <div class="section" id="specify-sharding"> <h2>Specify sharding<a class="headerlink" href="#specify-sharding" title="Permalink to this heading">#</a></h2> <p>Next, you need to tell <code class="docutils literal notranslate"><span class="pre">jax.jit</span></code> how to shard our data across devices.</p> <div class="section" id="the-input-s-sharding"> <h3>The input’s sharding<a class="headerlink" href="#the-input-s-sharding" title="Permalink to this heading">#</a></h3> <p>For data parallelism, you can shard the batched <em>input</em> <code class="docutils literal notranslate"><span class="pre">x</span></code> across the <code class="docutils literal notranslate"><span class="pre">data</span></code> axis by denoting the batch axis as <code class="docutils literal notranslate"><span class="pre">'data'</span></code>. Then, use <a class="reference external" href="https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html"><code class="docutils literal notranslate"><span class="pre">jax.device_put</span></code></a> to place it onto the correct <code class="docutils literal notranslate"><span class="pre">device</span></code>s.</p> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">x_sharding</span> <span class="o">=</span> <span class="n">mesh_sharding</span><span class="p">(</span><span class="n">PartitionSpec</span><span class="p">(</span><span class="s1">'data'</span><span class="p">,</span> <span class="kc">None</span><span class="p">))</span> <span class="c1"># dimensions: (batch, length)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">device_put</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">x_sharding</span><span class="p">)</span> <span class="n">jax</span><span class="o">.</span><span class="n">debug</span><span class="o">.</span><span class="n">visualize_array_sharding</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> </pre></div> </div> </div> <div class="cell_output docutils container"> <div class="output text_html"><pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> CPU 0,1,2,3 </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> CPU 4,5,6,7 </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> </pre> </div></div> </div> </div> <div class="section" id="the-output-s-sharding"> <h3>The output’s sharding<a class="headerlink" href="#the-output-s-sharding" title="Permalink to this heading">#</a></h3> <p>You need to compile <code class="docutils literal notranslate"><span class="pre">model.init()</span></code> (that is, <a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init"><code class="docutils literal notranslate"><span class="pre">flax.linen.Module.init()</span></code></a>), and its output as a pytree of parameters. Additionally, you may sometimes need wrap it with a <a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.training.html#flax.training.train_state.TrainState"><code class="docutils literal notranslate"><span class="pre">flax.training.train_state</span></code></a> to track other variables, such as optimizer states, and that would make the output an even more complex pytree.</p> <p>To achieve this, luckily, you don’t have to hardcode the output’s sharding by hand. Instead, you can:</p> <ol class="arabic simple"> <li><p>Evaluate <code class="docutils literal notranslate"><span class="pre">model.init</span></code> (in this case, a wrapper of it) abstractly using <a class="reference external" href="https://jax.readthedocs.io/en/latest/_autosummary/jax.eval_shape.html"><code class="docutils literal notranslate"><span class="pre">jax.eval_shape</span></code></a>.</p></li> <li><p>Use <a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.get_sharding"><code class="docutils literal notranslate"><span class="pre">flax.linen.get_sharding</span></code></a> to automatically generate the <code class="docutils literal notranslate"><span class="pre">jax.sharding.NamedSharding</span></code>.</p> <ul class="simple"> <li><p>This step utilizes the <a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_partitioning"><code class="docutils literal notranslate"><span class="pre">flax.linen.with_partitioning</span></code></a> annotations in the earlier definition to generate the correct sharding for the parameters.</p></li> </ul> </li> </ol> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">init_fn</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">):</span> <span class="n">variables</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="c1"># Initialize the model.</span> <span class="n">state</span> <span class="o">=</span> <span class="n">train_state</span><span class="o">.</span><span class="n">TrainState</span><span class="o">.</span><span class="n">create</span><span class="p">(</span> <span class="c1"># Create a `TrainState`.</span> <span class="n">apply_fn</span><span class="o">=</span><span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">,</span> <span class="n">params</span><span class="o">=</span><span class="n">variables</span><span class="p">[</span><span class="s1">'params'</span><span class="p">],</span> <span class="n">tx</span><span class="o">=</span><span class="n">optimizer</span><span class="p">)</span> <span class="k">return</span> <span class="n">state</span> </pre></div> </div> </div> </div> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Create an abstract closure to wrap the function before feeding it in</span> <span class="c1"># because `jax.eval_shape` only takes pytrees as arguments.</span> <span class="n">abstract_variables</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">eval_shape</span><span class="p">(</span> <span class="n">functools</span><span class="o">.</span><span class="n">partial</span><span class="p">(</span><span class="n">init_fn</span><span class="p">,</span> <span class="n">model</span><span class="o">=</span><span class="n">model</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">),</span> <span class="n">k</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="c1"># This `state_sharding` has the same pytree structure as `state`, the output</span> <span class="c1"># of the `init_fn`.</span> <span class="n">state_sharding</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">get_sharding</span><span class="p">(</span><span class="n">abstract_variables</span><span class="p">,</span> <span class="n">mesh</span><span class="p">)</span> <span class="n">state_sharding</span> </pre></div> </div> </div> <div class="cell_output docutils container"> <div class="output text_plain highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>TrainState(step=NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(), memory_kind=unpinned_host), apply_fn=<bound method Module.apply of MLP( # attributes num_layers = 4 depth = 1024 use_scan = False )>, params={'DotReluDot_0': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}, 'DotReluDot_1': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}, 'DotReluDot_2': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}, 'DotReluDot_3': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}}, tx=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn at 0x7fc60ff228c0>, update=<function chain.<locals>.update_fn at 0x7fc60ff22290>), opt_state=(ScaleByAdamState(count=NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(), memory_kind=unpinned_host), mu={'DotReluDot_0': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}, 'DotReluDot_1': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}, 'DotReluDot_2': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}, 'DotReluDot_3': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}}, nu={'DotReluDot_0': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}, 'DotReluDot_1': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}, 'DotReluDot_2': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}, 'DotReluDot_3': {'Dense_0': {'kernel': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host)}, 'W2': NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec('model', None), memory_kind=unpinned_host)}}), EmptyState())) </pre></div> </div> </div> </div> </div> </div> <div class="section" id="compile-the-code"> <h2>Compile the code<a class="headerlink" href="#compile-the-code" title="Permalink to this heading">#</a></h2> <p>Now you can apply <a class="reference external" href="https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html"><code class="docutils literal notranslate"><span class="pre">jax.jit</span></code></a> to your <code class="docutils literal notranslate"><span class="pre">init_fn</span></code>, but with two extra arguments: <code class="docutils literal notranslate"><span class="pre">in_shardings</span></code> and <code class="docutils literal notranslate"><span class="pre">out_shardings</span></code>.</p> <p>Run it to get the <code class="docutils literal notranslate"><span class="pre">initialized_state</span></code>, in which parameters are sharded exactly as instructed:</p> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">jit_init_fn</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">jit</span><span class="p">(</span><span class="n">init_fn</span><span class="p">,</span> <span class="n">static_argnums</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">in_shardings</span><span class="o">=</span><span class="p">(</span><span class="n">mesh_sharding</span><span class="p">(()),</span> <span class="n">x_sharding</span><span class="p">),</span> <span class="c1"># PRNG key and x</span> <span class="n">out_shardings</span><span class="o">=</span><span class="n">state_sharding</span><span class="p">)</span> <span class="n">initialized_state</span> <span class="o">=</span> <span class="n">jit_init_fn</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">model</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">)</span> <span class="c1"># for weight, partitioned in initialized_state.params['DotReluDot_0'].items():</span> <span class="c1"># print(f'Sharding of {weight}: {partitioned.names}')</span> <span class="n">jax</span><span class="o">.</span><span class="n">debug</span><span class="o">.</span><span class="n">visualize_array_sharding</span><span class="p">(</span><span class="n">initialized_state</span><span class="o">.</span><span class="n">params</span><span class="p">[</span><span class="s1">'DotReluDot_0'</span><span class="p">][</span><span class="s1">'Dense_0'</span><span class="p">][</span><span class="s1">'kernel'</span><span class="p">]</span><span class="o">.</span><span class="n">value</span><span class="p">)</span> <span class="n">jax</span><span class="o">.</span><span class="n">debug</span><span class="o">.</span><span class="n">visualize_array_sharding</span><span class="p">(</span><span class="n">initialized_state</span><span class="o">.</span><span class="n">params</span><span class="p">[</span><span class="s1">'DotReluDot_0'</span><span class="p">][</span><span class="s1">'W2'</span><span class="p">]</span><span class="o">.</span><span class="n">value</span><span class="p">)</span> </pre></div> </div> </div> <div class="cell_output docutils container"> <div class="output text_html"><pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> CPU 0,4 </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> CPU 1,5 </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> CPU 2,6 </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> CPU 3,7 </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> </pre> </div><div class="output text_html"><pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> CPU 0,4 </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> CPU 1,5 </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> CPU 2,6 </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span> <span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> CPU 3,7 </span> <span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> </pre> </div></div> </div> </div> <div class="section" id="inspect-the-module-output"> <h2>Inspect the Module output<a class="headerlink" href="#inspect-the-module-output" title="Permalink to this heading">#</a></h2> <p>Note that in the output of <code class="docutils literal notranslate"><span class="pre">initialized_state</span></code>, the <code class="docutils literal notranslate"><span class="pre">params</span></code> <code class="docutils literal notranslate"><span class="pre">W1</span></code> and <code class="docutils literal notranslate"><span class="pre">W2</span></code> are of type <a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.Partitioned"><code class="docutils literal notranslate"><span class="pre">flax.linen.Partitioned</span></code></a>. This is a wrapper around the actual <code class="docutils literal notranslate"><span class="pre">jax.Array</span></code> that allows Flax to record the axis names associated with it.</p> <p>You can access the raw <code class="docutils literal notranslate"><span class="pre">jax.Array</span></code>s by calling <code class="docutils literal notranslate"><span class="pre">flax.linen.meta.unbox()</span></code> upon the dictionary, or call <code class="docutils literal notranslate"><span class="pre">.value</span></code> upon individual variable. You can also use <code class="docutils literal notranslate"><span class="pre">flax.linen.meta.replace_boxed()</span></code> to change the underlying <code class="docutils literal notranslate"><span class="pre">jax.Array</span></code> without modifying the sharding annotations.</p> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="nb">print</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">initialized_state</span><span class="o">.</span><span class="n">params</span><span class="p">[</span><span class="s1">'DotReluDot_0'</span><span class="p">][</span><span class="s1">'Dense_0'</span><span class="p">][</span><span class="s1">'kernel'</span><span class="p">]))</span> <span class="nb">print</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">initialized_state</span><span class="o">.</span><span class="n">params</span><span class="p">[</span><span class="s1">'DotReluDot_0'</span><span class="p">][</span><span class="s1">'Dense_0'</span><span class="p">][</span><span class="s1">'kernel'</span><span class="p">]</span><span class="o">.</span><span class="n">value</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="n">initialized_state</span><span class="o">.</span><span class="n">params</span><span class="p">[</span><span class="s1">'DotReluDot_0'</span><span class="p">][</span><span class="s1">'Dense_0'</span><span class="p">][</span><span class="s1">'kernel'</span><span class="p">]</span><span class="o">.</span><span class="n">names</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">initialized_state</span><span class="o">.</span><span class="n">params</span><span class="p">[</span><span class="s1">'DotReluDot_0'</span><span class="p">][</span><span class="s1">'Dense_0'</span><span class="p">][</span><span class="s1">'kernel'</span><span class="p">]</span><span class="o">.</span><span class="n">value</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> </pre></div> </div> </div> <div class="cell_output docutils container"> <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span><class 'flax.core.meta.Partitioned'> <class 'jaxlib.xla_extension.ArrayImpl'> (None, 'model') (1024, 1024) </pre></div> </div> </div> </div> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Say for some unknown reason you want to make the whole param tree all-zero</span> <span class="n">unboxed_params</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">meta</span><span class="o">.</span><span class="n">unbox</span><span class="p">(</span><span class="n">initialized_state</span><span class="o">.</span><span class="n">params</span><span class="p">)</span> <span class="n">all_zero</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">,</span> <span class="n">unboxed_params</span><span class="p">)</span> <span class="n">all_zero_params</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">meta</span><span class="o">.</span><span class="n">replace_boxed</span><span class="p">(</span><span class="n">initialized_state</span><span class="o">.</span><span class="n">params</span><span class="p">,</span> <span class="n">all_zero</span><span class="p">)</span> <span class="k">assert</span> <span class="n">jnp</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">meta</span><span class="o">.</span><span class="n">unbox</span><span class="p">(</span><span class="n">all_zero_params</span><span class="p">[</span><span class="s1">'DotReluDot_0'</span><span class="p">][</span><span class="s1">'Dense_0'</span><span class="p">][</span><span class="s1">'kernel'</span><span class="p">]))</span> <span class="o">==</span> <span class="mi">0</span> </pre></div> </div> </div> </div> <p>You can also check the underlying <a class="reference external" href="https://jax.readthedocs.io/en/latest/jax.sharding.html"><code class="docutils literal notranslate"><span class="pre">jax.sharding</span></code></a> of each parameter, which is now more internal than <code class="docutils literal notranslate"><span class="pre">NamedSharding</span></code>. Note that numbers like <code class="docutils literal notranslate"><span class="pre">initialized_state.step</span></code> are replicated across all devices.</p> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">initialized_state</span><span class="o">.</span><span class="n">params</span><span class="p">[</span><span class="s1">'DotReluDot_0'</span><span class="p">][</span><span class="s1">'Dense_0'</span><span class="p">][</span><span class="s1">'kernel'</span><span class="p">]</span><span class="o">.</span><span class="n">value</span><span class="o">.</span><span class="n">sharding</span> </pre></div> </div> </div> <div class="cell_output docutils container"> <div class="output text_plain highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(None, 'model'), memory_kind=unpinned_host) </pre></div> </div> </div> </div> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="nb">print</span><span class="p">(</span><span class="n">initialized_state</span><span class="o">.</span><span class="n">step</span><span class="p">)</span> <span class="n">initialized_state</span><span class="o">.</span><span class="n">step</span><span class="o">.</span><span class="n">sharding</span> </pre></div> </div> </div> <div class="cell_output docutils container"> <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>0 </pre></div> </div> <div class="output text_plain highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>NamedSharding(mesh=Mesh('data': 2, 'model': 4), spec=PartitionSpec(), memory_kind=unpinned_host) </pre></div> </div> </div> </div> <p>You can use <a class="reference external" href="https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_map.html"><code class="docutils literal notranslate"><span class="pre">jax.tree_util.tree_map</span></code></a> to perform mass computation on a dict of boxed params, in the same way as on a dict of JAX arrays.</p> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">diff</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span> <span class="k">lambda</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">:</span> <span class="n">a</span> <span class="o">-</span> <span class="n">b</span><span class="p">,</span> <span class="n">initialized_state</span><span class="o">.</span><span class="n">params</span><span class="p">[</span><span class="s1">'DotReluDot_0'</span><span class="p">],</span> <span class="n">initialized_state</span><span class="o">.</span><span class="n">params</span><span class="p">[</span><span class="s1">'DotReluDot_0'</span><span class="p">])</span> <span class="nb">print</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">diff</span><span class="p">))</span> <span class="n">diff_array</span> <span class="o">=</span> <span class="n">diff</span><span class="p">[</span><span class="s1">'Dense_0'</span><span class="p">][</span><span class="s1">'kernel'</span><span class="p">]</span><span class="o">.</span><span class="n">value</span> <span class="nb">print</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">diff_array</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="n">diff_array</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> </pre></div> </div> </div> <div class="cell_output docutils container"> <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>{'Dense_0': {'kernel': Partitioned(value=(1024, 1024), names=(None, 'model'), mesh=None)}, 'W2': Partitioned(value=(1024, 1024), names=('model', None), mesh=None)} <class 'jaxlib.xla_extension.ArrayImpl'> (1024, 1024) </pre></div> </div> </div> </div> </div> <div class="section" id="compile-the-train-step-and-inference"> <h2>Compile the train step and inference<a class="headerlink" href="#compile-the-train-step-and-inference" title="Permalink to this heading">#</a></h2> <p>Create a <code class="docutils literal notranslate"><span class="pre">jit</span></code>ted training step as follows:</p> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="nd">@functools</span><span class="o">.</span><span class="n">partial</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">jit</span><span class="p">,</span> <span class="n">in_shardings</span><span class="o">=</span><span class="p">(</span><span class="n">state_sharding</span><span class="p">,</span> <span class="n">x_sharding</span><span class="p">),</span> <span class="n">out_shardings</span><span class="o">=</span><span class="n">state_sharding</span><span class="p">)</span> <span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="n">state</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="c1"># A fake loss function.</span> <span class="k">def</span> <span class="nf">loss_unrolled</span><span class="p">(</span><span class="n">params</span><span class="p">):</span> <span class="n">y</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">({</span><span class="s1">'params'</span><span class="p">:</span> <span class="n">params</span><span class="p">},</span> <span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">y</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span> <span class="n">grad_fn</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">grad</span><span class="p">(</span><span class="n">loss_unrolled</span><span class="p">)</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">grad_fn</span><span class="p">(</span><span class="n">state</span><span class="o">.</span><span class="n">params</span><span class="p">)</span> <span class="n">state</span> <span class="o">=</span> <span class="n">state</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="n">grads</span><span class="o">=</span><span class="n">grads</span><span class="p">)</span> <span class="k">return</span> <span class="n">state</span> <span class="k">with</span> <span class="n">mesh</span><span class="p">:</span> <span class="n">new_state</span> <span class="o">=</span> <span class="n">train_step</span><span class="p">(</span><span class="n">initialized_state</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> </pre></div> </div> </div> </div> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Sharding of Weight 1:'</span><span class="p">)</span> <span class="n">jax</span><span class="o">.</span><span class="n">debug</span><span class="o">.</span><span class="n">visualize_array_sharding</span><span class="p">(</span><span class="n">initialized_state</span><span class="o">.</span><span class="n">params</span><span class="p">[</span><span class="s1">'DotReluDot_0'</span><span class="p">][</span><span class="s1">'Dense_0'</span><span class="p">][</span><span class="s1">'kernel'</span><span class="p">]</span><span class="o">.</span><span class="n">value</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Sharding of Weight 2:'</span><span class="p">)</span> <span class="n">jax</span><span class="o">.</span><span class="n">debug</span><span class="o">.</span><span class="n">visualize_array_sharding</span><span class="p">(</span><span class="n">initialized_state</span><span class="o">.</span><span class="n">params</span><span class="p">[</span><span class="s1">'DotReluDot_0'</span><span class="p">][</span><span class="s1">'W2'</span><span class="p">]</span><span class="o">.</span><span class="n">value</span><span class="p">)</span> </pre></div> </div> </div> <div class="cell_output docutils container"> <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>Sharding of Weight 1: </pre></div> </div> <div class="output text_html"><pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> CPU 0,4 </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> CPU 1,5 </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> CPU 2,6 </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> CPU 3,7 </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> </pre> </div><div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>Sharding of Weight 2: </pre></div> </div> <div class="output text_html"><pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> CPU 0,4 </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> CPU 1,5 </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> CPU 2,6 </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span> <span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> CPU 3,7 </span> <span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> </pre> </div></div> </div> <p>Then, create a compiled inference step. Note that the output is also sharded along <code class="docutils literal notranslate"><span class="pre">(data,</span> <span class="pre">None)</span></code>.</p> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="nd">@functools</span><span class="o">.</span><span class="n">partial</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">jit</span><span class="p">,</span> <span class="n">in_shardings</span><span class="o">=</span><span class="p">(</span><span class="n">state_sharding</span><span class="p">,</span> <span class="n">x_sharding</span><span class="p">),</span> <span class="n">out_shardings</span><span class="o">=</span><span class="n">x_sharding</span><span class="p">)</span> <span class="k">def</span> <span class="nf">apply_fn</span><span class="p">(</span><span class="n">state</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="k">return</span> <span class="n">state</span><span class="o">.</span><span class="n">apply_fn</span><span class="p">({</span><span class="s1">'params'</span><span class="p">:</span> <span class="n">state</span><span class="o">.</span><span class="n">params</span><span class="p">},</span> <span class="n">x</span><span class="p">)</span> <span class="k">with</span> <span class="n">mesh</span><span class="p">:</span> <span class="n">y</span> <span class="o">=</span> <span class="n">apply_fn</span><span class="p">(</span><span class="n">new_state</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="nb">type</span><span class="p">(</span><span class="n">y</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">y</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="n">jax</span><span class="o">.</span><span class="n">debug</span><span class="o">.</span><span class="n">visualize_array_sharding</span><span class="p">(</span><span class="n">y</span><span class="p">)</span> </pre></div> </div> </div> <div class="cell_output docutils container"> <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span><class 'jaxlib.xla_extension.ArrayImpl'> float32 (8, 1024) </pre></div> </div> <div class="output text_html"><pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> CPU 0,1,2,3 </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> CPU 4,5,6,7 </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> </pre> </div></div> </div> </div> <div class="section" id="profiling"> <h2>Profiling<a class="headerlink" href="#profiling" title="Permalink to this heading">#</a></h2> <p>If you are running on a TPU pod or a pod slice, you can use a custom <code class="docutils literal notranslate"><span class="pre">block_all</span></code> utility function, as defined below, to measure the performance:</p> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="o">%%time</span>it <span class="k">def</span> <span class="nf">block_all</span><span class="p">(</span><span class="n">xs</span><span class="p">):</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="o">.</span><span class="n">block_until_ready</span><span class="p">(),</span> <span class="n">xs</span><span class="p">)</span> <span class="k">return</span> <span class="n">xs</span> <span class="k">with</span> <span class="n">mesh</span><span class="p">:</span> <span class="n">new_state</span> <span class="o">=</span> <span class="n">block_all</span><span class="p">(</span><span class="n">train_step</span><span class="p">(</span><span class="n">initialized_state</span><span class="p">,</span> <span class="n">x</span><span class="p">))</span> </pre></div> </div> </div> <div class="cell_output docutils container"> <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>237 ms ± 9.08 ms per loop (mean ± std. dev. of 7 runs, 1 loop each) </pre></div> </div> </div> </div> </div> <div class="section" id="logical-axis-annotation"> <h2>Logical axis annotation<a class="headerlink" href="#logical-axis-annotation" title="Permalink to this heading">#</a></h2> <p>JAX’s automatic SPMD encourages users to explore different sharding layouts to find the optimal one. To this end, in Flax you actually can annotate the dimensions of any data with more descriptive axis names (not just device mesh axis names like <code class="docutils literal notranslate"><span class="pre">'data'</span></code> and <code class="docutils literal notranslate"><span class="pre">'model'</span></code>).</p> <p>The <code class="docutils literal notranslate"><span class="pre">LogicalDotReluDot</span></code> and <code class="docutils literal notranslate"><span class="pre">LogicalMLP</span></code> Module definition below are similar to the Modules you created earlier, except for the following:</p> <ol class="arabic simple"> <li><p>All axes are annotated with more concrete, meaningful names, such as <code class="docutils literal notranslate"><span class="pre">'embed'</span></code>, <code class="docutils literal notranslate"><span class="pre">'hidden'</span></code>, <code class="docutils literal notranslate"><span class="pre">'batch'</span></code> and <code class="docutils literal notranslate"><span class="pre">'layer'</span></code>. These names are referred to as <em>logical axis names</em> in Flax. They make the dimensional changes inside model definitions more readable.</p></li> <li><p><a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_logical_partitioning"><code class="docutils literal notranslate"><span class="pre">flax.linen.with_logical_partitioning</span></code></a> replaces <code class="docutils literal notranslate"><span class="pre">flax.linen.with_partitioning</span></code>; and <a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.with_logical_constraint"><code class="docutils literal notranslate"><span class="pre">flax.linen.with_logical_constraint</span></code></a> replaces <code class="docutils literal notranslate"><span class="pre">jax.lax.with_sharding_constraint</span></code>, to recognize the logical axis names.</p></li> </ol> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">LogicalDotReluDot</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="n">depth</span><span class="p">:</span> <span class="nb">int</span> <span class="n">dense_init</span><span class="p">:</span> <span class="n">Callable</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">xavier_normal</span><span class="p">()</span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="n">y</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">depth</span><span class="p">,</span> <span class="n">kernel_init</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">with_logical_partitioning</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dense_init</span><span class="p">,</span> <span class="p">(</span><span class="s1">'embed'</span><span class="p">,</span> <span class="s1">'hidden'</span><span class="p">)),</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="c1"># or overwrite with `bias_init`</span> <span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">y</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">y</span><span class="p">)</span> <span class="c1"># Force a local sharding annotation.</span> <span class="n">y</span> <span class="o">=</span> <span class="n">with_sharding_constraint</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">mesh_sharding</span><span class="p">(</span><span class="n">PartitionSpec</span><span class="p">(</span><span class="s1">'data'</span><span class="p">,</span> <span class="s1">'model'</span><span class="p">)))</span> <span class="n">W2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">param</span><span class="p">(</span> <span class="s1">'W2'</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">with_logical_partitioning</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dense_init</span><span class="p">,</span> <span class="p">(</span><span class="s1">'hidden'</span><span class="p">,</span> <span class="s1">'embed'</span><span class="p">)),</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">depth</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]))</span> <span class="n">z</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">W2</span><span class="p">)</span> <span class="c1"># Force a local sharding annotation.</span> <span class="n">z</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">with_logical_constraint</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="p">(</span><span class="s1">'batch'</span><span class="p">,</span> <span class="s1">'embed'</span><span class="p">))</span> <span class="k">return</span> <span class="n">z</span><span class="p">,</span> <span class="kc">None</span> <span class="k">class</span> <span class="nc">LogicalMLP</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="n">num_layers</span><span class="p">:</span> <span class="nb">int</span> <span class="n">depth</span><span class="p">:</span> <span class="nb">int</span> <span class="n">use_scan</span><span class="p">:</span> <span class="nb">bool</span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">use_scan</span><span class="p">:</span> <span class="n">x</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">scan</span><span class="p">(</span><span class="n">LogicalDotReluDot</span><span class="p">,</span> <span class="n">length</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">,</span> <span class="n">variable_axes</span><span class="o">=</span><span class="p">{</span><span class="s2">"params"</span><span class="p">:</span> <span class="mi">0</span><span class="p">},</span> <span class="n">split_rngs</span><span class="o">=</span><span class="p">{</span><span class="s2">"params"</span><span class="p">:</span> <span class="kc">True</span><span class="p">},</span> <span class="n">metadata_params</span><span class="o">=</span><span class="p">{</span><span class="n">nn</span><span class="o">.</span><span class="n">PARTITION_NAME</span><span class="p">:</span> <span class="s1">'layer'</span><span class="p">}</span> <span class="p">)(</span><span class="bp">self</span><span class="o">.</span><span class="n">depth</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">):</span> <span class="n">x</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">LogicalDotReluDot</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">depth</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> </pre></div> </div> </div> </div> <p>Now, initiate a model and try to figure out what sharding its <code class="docutils literal notranslate"><span class="pre">state</span></code> should have.</p> <p>To allow the device mesh to take your model correctly, you need to decide which of these logical axis names are mapped to the device axis <code class="docutils literal notranslate"><span class="pre">'data'</span></code> or <code class="docutils literal notranslate"><span class="pre">'model'</span></code>. This rule is a list of (<code class="docutils literal notranslate"><span class="pre">logical_axis_name</span></code>, <code class="docutils literal notranslate"><span class="pre">device_axis_name</span></code>) tuples, and <a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.linen/spmd.html#flax.linen.logical_to_mesh_sharding"><code class="docutils literal notranslate"><span class="pre">flax.linen.logical_to_mesh_sharding</span></code></a> will convert them to the kind of sharding that the device mesh can understand.</p> <p>This allows you to change the rules and try out new partition layouts without modifying the model definition.</p> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="c1"># Unspecified rule means unsharded by default, so no need to specify `('embed', None)` and `('layer', None)`.</span> <span class="n">rules</span> <span class="o">=</span> <span class="p">((</span><span class="s1">'batch'</span><span class="p">,</span> <span class="s1">'data'</span><span class="p">),</span> <span class="p">(</span><span class="s1">'hidden'</span><span class="p">,</span> <span class="s1">'model'</span><span class="p">))</span> <span class="n">logical_model</span> <span class="o">=</span> <span class="n">LogicalMLP</span><span class="p">(</span><span class="n">LAYERS</span><span class="p">,</span> <span class="n">DEPTH</span><span class="p">,</span> <span class="n">USE_SCAN</span><span class="p">)</span> <span class="n">logical_abstract_variables</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">eval_shape</span><span class="p">(</span> <span class="n">functools</span><span class="o">.</span><span class="n">partial</span><span class="p">(</span><span class="n">init_fn</span><span class="p">,</span> <span class="n">model</span><span class="o">=</span><span class="n">logical_model</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">),</span> <span class="n">k</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="n">logical_state_spec</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">get_partition_spec</span><span class="p">(</span><span class="n">logical_abstract_variables</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s1">'annotations are logical, not mesh-specific: '</span><span class="p">,</span> <span class="n">logical_state_spec</span><span class="o">.</span><span class="n">params</span><span class="p">[</span><span class="s1">'LogicalDotReluDot_0'</span><span class="p">][</span><span class="s1">'Dense_0'</span><span class="p">][</span><span class="s1">'kernel'</span><span class="p">])</span> <span class="n">logical_state_sharding</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">logical_to_mesh_sharding</span><span class="p">(</span><span class="n">logical_state_spec</span><span class="p">,</span> <span class="n">mesh</span><span class="p">,</span> <span class="n">rules</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s1">'sharding annotations are mesh-specific: '</span><span class="p">,</span> <span class="n">logical_state_sharding</span><span class="o">.</span><span class="n">params</span><span class="p">[</span><span class="s1">'LogicalDotReluDot_0'</span><span class="p">][</span><span class="s1">'Dense_0'</span><span class="p">][</span><span class="s1">'kernel'</span><span class="p">]</span><span class="o">.</span><span class="n">spec</span><span class="p">)</span> </pre></div> </div> </div> <div class="cell_output docutils container"> <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>annotations are logical, not mesh-specific: PartitionSpec('embed', 'hidden') sharding annotations are mesh-specific: PartitionSpec(None, 'model') </pre></div> </div> </div> </div> <p>You can verify that the <code class="docutils literal notranslate"><span class="pre">logical_state_spec</span></code> here has the same content as <code class="docutils literal notranslate"><span class="pre">state_spec</span></code> in the previous (“non-logical”) example. This allows you to <code class="docutils literal notranslate"><span class="pre">jax.jit</span></code> your Module’s <a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.init"><code class="docutils literal notranslate"><span class="pre">flax.linen.Module.init</span></code></a> and <a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.Module.apply"><code class="docutils literal notranslate"><span class="pre">flax.linen.Module.apply</span></code></a> the same way in the above above.</p> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">state_sharding</span><span class="o">.</span><span class="n">params</span><span class="p">[</span><span class="s1">'DotReluDot_0'</span><span class="p">]</span> <span class="o">==</span> <span class="n">logical_state_sharding</span><span class="o">.</span><span class="n">params</span><span class="p">[</span><span class="s1">'LogicalDotReluDot_0'</span><span class="p">]</span> </pre></div> </div> </div> <div class="cell_output docutils container"> <div class="output text_plain highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>True </pre></div> </div> </div> </div> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="n">logical_jit_init_fn</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">jit</span><span class="p">(</span><span class="n">init_fn</span><span class="p">,</span> <span class="n">static_argnums</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">in_shardings</span><span class="o">=</span><span class="p">(</span><span class="n">mesh_sharding</span><span class="p">(()),</span> <span class="n">x_sharding</span><span class="p">),</span> <span class="c1"># PRNG key and x</span> <span class="n">out_shardings</span><span class="o">=</span><span class="n">logical_state_sharding</span><span class="p">)</span> <span class="n">logical_initialized_state</span> <span class="o">=</span> <span class="n">logical_jit_init_fn</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">logical_model</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">)</span> </pre></div> </div> </div> </div> <div class="cell docutils container"> <div class="cell_input docutils container"> <div class="highlight-ipython3 notranslate"><div class="highlight"><pre><span></span><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Sharding of Weight 1:'</span><span class="p">)</span> <span class="n">jax</span><span class="o">.</span><span class="n">debug</span><span class="o">.</span><span class="n">visualize_array_sharding</span><span class="p">(</span><span class="n">logical_initialized_state</span><span class="o">.</span><span class="n">params</span><span class="p">[</span><span class="s1">'LogicalDotReluDot_0'</span><span class="p">][</span><span class="s1">'Dense_0'</span><span class="p">][</span><span class="s1">'kernel'</span><span class="p">]</span><span class="o">.</span><span class="n">value</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Sharding of Weight 2:'</span><span class="p">)</span> <span class="n">jax</span><span class="o">.</span><span class="n">debug</span><span class="o">.</span><span class="n">visualize_array_sharding</span><span class="p">(</span><span class="n">logical_initialized_state</span><span class="o">.</span><span class="n">params</span><span class="p">[</span><span class="s1">'LogicalDotReluDot_0'</span><span class="p">][</span><span class="s1">'W2'</span><span class="p">]</span><span class="o">.</span><span class="n">value</span><span class="p">)</span> </pre></div> </div> </div> <div class="cell_output docutils container"> <div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>Sharding of Weight 1: </pre></div> </div> <div class="output text_html"><pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> CPU 0,4 </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> CPU 1,5 </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> CPU 2,6 </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> CPU 3,7 </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span><span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> </pre> </div><div class="output stream highlight-myst-ansi notranslate"><div class="highlight"><pre><span></span>Sharding of Weight 2: </pre></div> </div> <div class="output text_html"><pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> CPU 0,4 </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #393b79"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> CPU 1,5 </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #de9ed6"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> CPU 2,6 </span> <span style="color: #ffffff; text-decoration-color: #ffffff; background-color: #ad494a"> </span> <span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> <span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> CPU 3,7 </span> <span style="color: #000000; text-decoration-color: #000000; background-color: #b5cf6b"> </span> </pre> </div></div> </div> </div> <div class="section" id="when-to-use-device-axis-logical-axis"> <h2>When to use device axis / logical axis<a class="headerlink" href="#when-to-use-device-axis-logical-axis" title="Permalink to this heading">#</a></h2> <p>Choosing when to use a device or logical axis depends on how much you want to control the partitioning of your model:</p> <ul class="simple"> <li><p><strong>Device mesh axis</strong>: If you want a very simple model, or you are very confident of your way of partitioning, defining it with <strong>device mesh axis</strong> can potentially save you a few extra lines of code of converting the logical naming back to the device naming.</p></li> <li><p><strong>Logical naming</strong>: On the other hand, the <strong>logical naming</strong> helpers can be useful for exploring different sharding layouts. Use this if you want to experiment around and find the most optimal partition layout for your model.</p></li> <li><p><strong>Device axis names</strong>: In really advanced use cases, you may have more complicated sharding patterns that require annotating <em>activation</em> dimension names differently from <em>parameter</em> dimension names. If you wish to have more fine-grained control on manual mesh assignments, directly using <strong>device axis names</strong> could be more helpful.</p></li> </ul> </div> <div class="section" id="save-the-data"> <h2>Save the data<a class="headerlink" href="#save-the-data" title="Permalink to this heading">#</a></h2> <p>To save the cross-device array, you can use <a class="reference external" href="https://flax.readthedocs.io/en/latest/_modules/flax/training/checkpoints.html"><code class="docutils literal notranslate"><span class="pre">flax.training.checkpoints</span></code></a>, as shown in the <a class="reference external" href="https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#multi-host-multi-process-checkpointing">Save and load checkpoints guide - Multi-host/multi-process checkpointing</a>. This is especially required if you are running on a multi-host environment (for example, a TPU pod).</p> <p>In practice, you might want to save the raw <code class="docutils literal notranslate"><span class="pre">jax.Array</span></code> pytree as checkpoint, instead of the wrapped <code class="docutils literal notranslate"><span class="pre">Partitioned</span></code> values, to reduce complexity. You can restore it as-is and put it back into an annotated pytree with <code class="docutils literal notranslate"><span class="pre">flax.linen.meta.replace_boxed()</span></code>.</p> <p>Keep in mind that to restore the arrays to the desired partition, you need to provide a sample <code class="docutils literal notranslate"><span class="pre">target</span></code> pytree that has the same structure and has the desired <a class="reference external" href="https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Sharding"><code class="docutils literal notranslate"><span class="pre">jax.sharding.Sharding</span></code></a> in place for each JAX array. The sharding you use to restore the array doesn’t necessarily need to be the same as the ones you used to store the array.</p> </div> </div> </article> <footer class="prev-next-footer d-print-none"> <div class="prev-next-area"> <a class="left-prev" href="ensembling.html" title="previous page"> <i class="fa-solid fa-angle-left"></i> <div class="prev-next-info"> <p class="prev-next-subtitle">previous</p> <p class="prev-next-title">Ensembling on multiple devices</p> </div> </a> <a class="right-next" href="../model_inspection/index.html" title="next page"> <div class="prev-next-info"> <p class="prev-next-subtitle">next</p> <p class="prev-next-title">Model inspection</p> </div> <i class="fa-solid fa-angle-right"></i> </a> </div> </footer> </div> <div class="bd-sidebar-secondary bd-toc"><div class="sidebar-secondary-items sidebar-secondary__inner"> <div class="sidebar-secondary-item"> <div class="page-toc tocsection onthispage"> <i class="fa-solid fa-list"></i> Contents </div> <nav class="bd-toc-nav page-toc"> <ul class="visible nav section-nav flex-column"> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax-and-jax-jit-scaled-up">Flax and <code class="docutils literal notranslate"><span class="pre">jax.jit</span></code> scaled up</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#setup">Setup</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#define-a-layer">Define a layer</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#define-a-model-with-flax-linen-scan-lifted-transformation">Define a model with <code class="docutils literal notranslate"><span class="pre">flax.linen.scan</span></code> lifted transformation</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#specify-sharding">Specify sharding</a><ul class="nav section-nav flex-column"> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#the-input-s-sharding">The input’s sharding</a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#the-output-s-sharding">The output’s sharding</a></li> </ul> </li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#compile-the-code">Compile the code</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#inspect-the-module-output">Inspect the Module output</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#compile-the-train-step-and-inference">Compile the train step and inference</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#profiling">Profiling</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#logical-axis-annotation">Logical axis annotation</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#when-to-use-device-axis-logical-axis">When to use device axis / logical axis</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#save-the-data">Save the data</a></li> </ul> </nav></div> </div></div> </div> <footer class="bd-footer-content"> <div class="bd-footer-content__inner container"> <div class="footer-item"> <p class="component-author"> By The Flax authors </p> </div> <div class="footer-item"> <p class="copyright"> © Copyright 2023, The Flax authors. <br/> </p> </div> <div class="footer-item"> </div> <div class="footer-item"> </div> </div> </footer> </main> </div> </div> <!-- Scripts loaded after <body> so the DOM is not blocked --> <script src="../../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b"></script> <script src="../../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b"></script> <footer class="bd-footer"> </footer> </body> </html>