CINXE.COM
jax.sharding module — JAX documentation
<!DOCTYPE html> <html lang="en" data-content_root="./" > <head> <meta charset="utf-8" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="viewport" content="width=device-width, initial-scale=1" /> <title>jax.sharding module — JAX documentation</title> <script data-cfasync="false"> document.documentElement.dataset.mode = localStorage.getItem("mode") || ""; document.documentElement.dataset.theme = localStorage.getItem("theme") || "light"; </script> <!-- Loaded before other Sphinx assets --> <link href="_static/styles/theme.css?digest=5b4479735964841361fd" rel="stylesheet" /> <link href="_static/styles/bootstrap.css?digest=5b4479735964841361fd" rel="stylesheet" /> <link href="_static/styles/pydata-sphinx-theme.css?digest=5b4479735964841361fd" rel="stylesheet" /> <link href="_static/vendor/fontawesome/6.1.2/css/all.min.css?digest=5b4479735964841361fd" rel="stylesheet" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="_static/vendor/fontawesome/6.1.2/webfonts/fa-solid-900.woff2" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="_static/vendor/fontawesome/6.1.2/webfonts/fa-brands-400.woff2" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="_static/vendor/fontawesome/6.1.2/webfonts/fa-regular-400.woff2" /> <link rel="stylesheet" type="text/css" href="_static/pygments.css?v=8f2a1f02" /> <link rel="stylesheet" type="text/css" href="_static/styles/sphinx-book-theme.css?v=384b581d" /> <link rel="stylesheet" type="text/css" href="_static/plot_directive.css" /> <link rel="stylesheet" type="text/css" href="_static/mystnb.4510f1fc1dee50b3e5859aac5469c37c29e427902b24a333a5f9fcb2f0b3ac41.css" /> <link rel="stylesheet" type="text/css" href="_static/copybutton.css?v=76b2166b" /> <link rel="stylesheet" type="text/css" href="_static/sphinx-design.min.css?v=95c83b7e" /> <link rel="stylesheet" type="text/css" href="_static/style.css?v=05b22d1f" /> <link rel="stylesheet" href="_static/style.css" 
type="text/css" /> <!-- Pre-loaded scripts that we'll load fully later --> <link rel="preload" as="script" href="_static/scripts/bootstrap.js?digest=5b4479735964841361fd" /> <link rel="preload" as="script" href="_static/scripts/pydata-sphinx-theme.js?digest=5b4479735964841361fd" /> <script src="_static/vendor/fontawesome/6.1.2/js/all.min.js?digest=5b4479735964841361fd"></script> <script src="_static/documentation_options.js?v=9eb32ce0"></script> <script src="_static/doctools.js?v=9a2dae69"></script> <script src="_static/sphinx_highlight.js?v=dc90522c"></script> <script src="_static/clipboard.min.js?v=a7894cd8"></script> <script src="_static/copybutton.js?v=30646c52"></script> <script src="_static/scripts/sphinx-book-theme.js?v=efea14e4"></script> <script src="_static/design-tabs.js?v=f930bc37"></script> <script>DOCUMENTATION_OPTIONS.pagename = 'jax.sharding';</script> <link rel="icon" href="_static/favicon.png"/> <link rel="author" title="About these documents" href="about.html" /> <link rel="index" title="Index" href="genindex.html" /> <link rel="search" title="Search" href="search.html" /> <link rel="next" title="jax.debug module" href="jax.debug.html" /> <link rel="prev" title="jax.random.weibull_min" href="_autosummary/jax.random.weibull_min.html" /> <meta name="viewport" content="width=device-width, initial-scale=1"/> <meta name="docsearch:language" content="en"/> <script async type="text/javascript" src="/_/static/javascript/readthedocs-addons.js"></script><meta name="readthedocs-project-slug" content="jax" /><meta name="readthedocs-version-slug" content="latest" /><meta name="readthedocs-resolver-filename" content="/jax.sharding.html" /><meta name="readthedocs-http-status" content="200" /></head> <body data-bs-spy="scroll" data-bs-target=".bd-toc-nav" data-offset="180" data-bs-root-margin="0px 0px -60%" data-default-mode=""> <a class="skip-link" href="#main-content">Skip to main content</a> <div id="pst-scroll-pixel-helper"></div> <button type="button" 
class="btn rounded-pill" id="pst-back-to-top"> <i class="fa-solid fa-arrow-up"></i> Back to top </button> <input type="checkbox" class="sidebar-toggle" name="__primary" id="__primary"/> <label class="overlay overlay-primary" for="__primary"></label> <input type="checkbox" class="sidebar-toggle" name="__secondary" id="__secondary"/> <label class="overlay overlay-secondary" for="__secondary"></label> <div class="search-button__wrapper"> <div class="search-button__overlay"></div> <div class="search-button__search-container"> <form class="bd-search d-flex align-items-center" action="search.html" method="get"> <i class="fa-solid fa-magnifying-glass"></i> <input type="search" class="form-control" name="q" id="search-input" placeholder="Search..." aria-label="Search..." autocomplete="off" autocorrect="off" autocapitalize="off" spellcheck="false"/> <span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span> </form></div> </div> <nav class="bd-header navbar navbar-expand-lg bd-navbar"> </nav> <div class="bd-container"> <div class="bd-container__inner bd-page-width"> <div class="bd-sidebar-primary bd-sidebar"> <div class="sidebar-header-items sidebar-primary__section"> </div> <div class="sidebar-primary-items__start sidebar-primary__section"> <div class="sidebar-primary-item"> <a class="navbar-brand logo" href="index.html"> <img src="_static/jax_logo_250px.png" class="logo__image only-light" alt="JAX documentation - Home"/> <script>document.write(`<img src="_static/jax_logo_250px.png" class="logo__image only-dark" alt="JAX documentation - Home"/>`);</script> </a></div> <div class="sidebar-primary-item"> <script> document.write(` <button class="btn navbar-btn search-button-field search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="fa-solid fa-magnifying-glass"></i> <span class="search-button__default-text">Search</span> <span 
class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span> </button> `); </script></div> <div class="sidebar-primary-item"><nav class="bd-links bd-docs-nav" aria-label="Main"> <div class="bd-toc-item navbar-nav active"> <p aria-level="2" class="caption" role="heading"><span class="caption-text">Getting started</span></p> <ul class="nav bd-sidenav"> <li class="toctree-l1"><a class="reference internal" href="installation.html">Installation</a></li> <li class="toctree-l1"><a class="reference internal" href="quickstart.html">Quickstart</a></li> </ul> <ul class="nav bd-sidenav"> <li class="toctree-l1 has-children"><a class="reference internal" href="tutorials.html">Tutorials</a><input class="toctree-checkbox" id="toctree-checkbox-1" name="toctree-checkbox-1" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-1"><i class="fa-solid fa-chevron-down"></i></label><ul> <li class="toctree-l2"><a class="reference internal" href="quickstart.html">Quickstart</a></li> <li class="toctree-l2"><a class="reference internal" href="key-concepts.html">Key concepts</a></li> <li class="toctree-l2"><a class="reference internal" href="jit-compilation.html">Just-in-time compilation</a></li> <li class="toctree-l2"><a class="reference internal" href="automatic-vectorization.html">Automatic vectorization</a></li> <li class="toctree-l2"><a class="reference internal" href="automatic-differentiation.html">Automatic differentiation</a></li> <li class="toctree-l2"><a class="reference internal" href="debugging.html">Introduction to debugging</a></li> <li class="toctree-l2"><a class="reference internal" href="random-numbers.html">Pseudorandom numbers</a></li> <li class="toctree-l2"><a class="reference internal" href="working-with-pytrees.html">Working with pytrees</a></li> <li class="toctree-l2"><a class="reference internal" href="sharded-computation.html">Introduction to parallel programming</a></li> <li 
class="toctree-l2"><a class="reference internal" href="stateful-computations.html">Stateful computations</a></li> <li class="toctree-l2"><a class="reference internal" href="control-flow.html">Control flow and logical operators with JIT</a></li> <li class="toctree-l2"><a class="reference internal" href="advanced-autodiff.html">Advanced automatic differentiation</a></li> <li class="toctree-l2"><a class="reference internal" href="external-callbacks.html">External callbacks</a></li> <li class="toctree-l2"><a class="reference internal" href="gradient-checkpointing.html">Gradient checkpointing with <code class="docutils literal notranslate"><span class="pre">jax.checkpoint</span></code> (<code class="docutils literal notranslate"><span class="pre">jax.remat</span></code>)</a></li> <li class="toctree-l2"><a class="reference internal" href="jax-primitives.html">JAX Internals: primitives</a></li> <li class="toctree-l2"><a class="reference internal" href="jaxpr.html">JAX internals: The jaxpr language</a></li> </ul> </li> <li class="toctree-l1"><a class="reference internal" href="notebooks/Common_Gotchas_in_JAX.html">🔪 JAX - The Sharp Bits 🔪</a></li> <li class="toctree-l1"><a class="reference internal" href="faq.html">Frequently asked questions (FAQ)</a></li> </ul> <p aria-level="2" class="caption" role="heading"><span class="caption-text">More guides/resources</span></p> <ul class="current nav bd-sidenav"> <li class="toctree-l1 has-children"><a class="reference internal" href="user_guides.html">User guides</a><input class="toctree-checkbox" id="toctree-checkbox-2" name="toctree-checkbox-2" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-2"><i class="fa-solid fa-chevron-down"></i></label><ul> <li class="toctree-l2"><a class="reference internal" href="notebooks/thinking_in_jax.html">How to think in JAX</a></li> <li class="toctree-l2"><a class="reference internal" href="profiling.html">Profiling computation</a></li> <li class="toctree-l2"><a 
class="reference internal" href="device_memory_profiling.html">Profiling device memory</a></li> <li class="toctree-l2 has-children"><a class="reference internal" href="debugging/index.html">Debugging runtime values</a><input class="toctree-checkbox" id="toctree-checkbox-3" name="toctree-checkbox-3" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-3"><i class="fa-solid fa-chevron-down"></i></label><ul> <li class="toctree-l3"><a class="reference internal" href="debugging/print_breakpoint.html">Compiled prints and breakpoints</a></li> <li class="toctree-l3"><a class="reference internal" href="debugging/checkify_guide.html">The <code class="docutils literal notranslate"><span class="pre">checkify</span></code> transformation</a></li> <li class="toctree-l3"><a class="reference internal" href="debugging/flags.html">JAX debugging flags</a></li> </ul> </li> <li class="toctree-l2"><a class="reference internal" href="gpu_performance_tips.html">GPU performance tips</a></li> <li class="toctree-l2"><a class="reference internal" href="persistent_compilation_cache.html">Persistent compilation cache</a></li> <li class="toctree-l2"><a class="reference internal" href="pytrees.html">Pytrees</a></li> <li class="toctree-l2"><a class="reference internal" href="errors.html">Errors</a></li> <li class="toctree-l2"><a class="reference internal" href="aot.html">Ahead-of-time lowering and compilation</a></li> <li class="toctree-l2 has-children"><a class="reference internal" href="export/index.html">Exporting and serialization</a><input class="toctree-checkbox" id="toctree-checkbox-4" name="toctree-checkbox-4" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-4"><i class="fa-solid fa-chevron-down"></i></label><ul> <li class="toctree-l3"><a class="reference internal" href="export/export.html">Exporting and serializing staged-out computations</a></li> <li class="toctree-l3"><a class="reference internal" href="export/shape_poly.html">Shape 
polymorphism</a></li> <li class="toctree-l3"><a class="reference internal" href="export/jax2tf.html">Interoperation with TensorFlow</a></li> </ul> </li> <li class="toctree-l2"><a class="reference internal" href="type_promotion.html">Type promotion semantics</a></li> <li class="toctree-l2"><a class="reference internal" href="transfer_guard.html">Transfer guard</a></li> <li class="toctree-l2 has-children"><a class="reference internal" href="pallas/index.html">Pallas: a JAX kernel language</a><input class="toctree-checkbox" id="toctree-checkbox-5" name="toctree-checkbox-5" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-5"><i class="fa-solid fa-chevron-down"></i></label><ul> <li class="toctree-l3"><a class="reference internal" href="pallas/quickstart.html">Pallas Quickstart</a></li> <li class="toctree-l3"><a class="reference internal" href="pallas/grid_blockspec.html">Grids and BlockSpecs</a></li> <li class="toctree-l3 has-children"><a class="reference internal" href="pallas/tpu/index.html">Pallas TPU</a><input class="toctree-checkbox" id="toctree-checkbox-6" name="toctree-checkbox-6" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-6"><i class="fa-solid fa-chevron-down"></i></label><ul> <li class="toctree-l4"><a class="reference internal" href="pallas/tpu/details.html">Writing TPU kernels with Pallas</a></li> <li class="toctree-l4"><a class="reference internal" href="pallas/tpu/pipelining.html">Pipelining</a></li> <li class="toctree-l4"><a class="reference internal" href="pallas/tpu/matmul.html">Matrix Multiplication</a></li> <li class="toctree-l4"><a class="reference internal" href="pallas/tpu/sparse.html">Scalar Prefetch and Block-Sparse Computation</a></li> <li class="toctree-l4"><a class="reference internal" href="pallas/tpu/distributed.html">Distributed Computing in Pallas for TPUs</a></li> </ul> </li> <li class="toctree-l3 has-children"><a class="reference internal" href="pallas/design/index.html">Pallas Design 
Notes</a><input class="toctree-checkbox" id="toctree-checkbox-7" name="toctree-checkbox-7" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-7"><i class="fa-solid fa-chevron-down"></i></label><ul> <li class="toctree-l4"><a class="reference internal" href="pallas/design/design.html">Pallas Design</a></li> <li class="toctree-l4"><a class="reference internal" href="pallas/design/async_note.html">Pallas Async Operations</a></li> </ul> </li> <li class="toctree-l3"><a class="reference internal" href="pallas/CHANGELOG.html">Pallas Changelog</a></li> </ul> </li> <li class="toctree-l2"><a class="reference internal" href="ffi.html">Foreign function interface (FFI)</a></li> <li class="toctree-l2"><a class="reference internal" href="notebooks/neural_network_with_tfds_data.html">Training a simple neural network, with tensorflow/datasets data loading</a></li> <li class="toctree-l2"><a class="reference internal" href="notebooks/Neural_Network_and_Data_Loading.html">Training a simple neural network, with PyTorch data loading</a></li> <li class="toctree-l2"><a class="reference internal" href="notebooks/vmapped_log_probs.html">Autobatching for Bayesian inference</a></li> </ul> </li> <li class="toctree-l1 has-children"><a class="reference internal" href="advanced_guide.html">Advanced guides</a><input class="toctree-checkbox" id="toctree-checkbox-8" name="toctree-checkbox-8" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-8"><i class="fa-solid fa-chevron-down"></i></label><ul> <li class="toctree-l2"><a class="reference internal" href="notebooks/Distributed_arrays_and_automatic_parallelization.html">Distributed arrays and automatic parallelization</a></li> <li class="toctree-l2"><a class="reference internal" href="notebooks/shard_map.html">Manual parallelism with <code class="docutils literal notranslate"><span class="pre">shard_map</span></code></a></li> <li class="toctree-l2"><a class="reference internal" href="multi_process.html">Multi-host 
and multi-process environments</a></li> <li class="toctree-l2"><a class="reference internal" href="distributed_data_loading.html">Distributed data loading</a></li> <li class="toctree-l2"><a class="reference internal" href="notebooks/autodiff_cookbook.html">The Autodiff Cookbook</a></li> <li class="toctree-l2"><a class="reference internal" href="notebooks/Custom_derivative_rules_for_Python_code.html">Custom derivative rules</a></li> <li class="toctree-l2"><a class="reference internal" href="notebooks/autodiff_remat.html">Control autodiff’s saved values with <code class="docutils literal notranslate"><span class="pre">jax.checkpoint</span></code> (aka <code class="docutils literal notranslate"><span class="pre">jax.remat</span></code>)</a></li> <li class="toctree-l2"><a class="reference internal" href="notebooks/convolutions.html">Generalized convolutions in JAX</a></li> <li class="toctree-l2"><a class="reference internal" href="xla_flags.html">List of XLA compiler flags</a></li> </ul> </li> <li class="toctree-l1 has-children"><a class="reference internal" href="contributor_guide.html">Developer notes</a><input class="toctree-checkbox" id="toctree-checkbox-9" name="toctree-checkbox-9" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-9"><i class="fa-solid fa-chevron-down"></i></label><ul> <li class="toctree-l2"><a class="reference internal" href="contributing.html">Contributing to JAX</a></li> <li class="toctree-l2"><a class="reference internal" href="developer.html">Building from source</a></li> <li class="toctree-l2"><a class="reference internal" href="investigating_a_regression.html">Investigating a regression</a></li> <li class="toctree-l2"><a class="reference internal" href="autodidax.html">Autodidax: JAX core from scratch</a></li> <li class="toctree-l2"><a class="reference internal" href="autodidax2_part1.html">Autodidax2, part 1: JAX from scratch, again</a></li> <li class="toctree-l2 has-children"><a class="reference internal" 
href="jep/index.html">JAX Enhancement Proposals (JEPs)</a><input class="toctree-checkbox" id="toctree-checkbox-10" name="toctree-checkbox-10" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-10"><i class="fa-solid fa-chevron-down"></i></label><ul> <li class="toctree-l3"><a class="reference internal" href="jep/263-prng.html">263: JAX PRNG Design</a></li> <li class="toctree-l3"><a class="reference internal" href="jep/2026-custom-derivatives.html">2026: Custom JVP/VJP rules for JAX-transformable functions</a></li> <li class="toctree-l3"><a class="reference internal" href="jep/4008-custom-vjp-update.html">4008: Custom VJP and `nondiff_argnums` update</a></li> <li class="toctree-l3"><a class="reference internal" href="jep/4410-omnistaging.html">4410: Omnistaging</a></li> <li class="toctree-l3"><a class="reference internal" href="jep/9263-typed-keys.html">9263: Typed keys & pluggable RNGs</a></li> <li class="toctree-l3"><a class="reference internal" href="jep/9407-type-promotion.html">9407: Design of Type Promotion Semantics for JAX</a></li> <li class="toctree-l3"><a class="reference internal" href="jep/9419-jax-versioning.html">9419: Jax and Jaxlib versioning</a></li> <li class="toctree-l3"><a class="reference internal" href="jep/10657-sequencing-effects.html">10657: Sequencing side-effects in JAX</a></li> <li class="toctree-l3"><a class="reference internal" href="jep/11830-new-remat-checkpoint.html">11830: `jax.remat` / `jax.checkpoint` new implementation</a></li> <li class="toctree-l3"><a class="reference internal" href="jep/12049-type-annotations.html">12049: Type Annotation Roadmap for JAX</a></li> <li class="toctree-l3"><a class="reference internal" href="jep/14273-shard-map.html">14273: `shard_map` (`shmap`) for simple per-device code</a></li> <li class="toctree-l3"><a class="reference internal" href="jep/15856-jex.html">15856: `jax.extend`, an extensions module</a></li> <li class="toctree-l3"><a class="reference internal" 
href="jep/17111-shmap-transpose.html">17111: Efficient transposition of `shard_map` (and other maps)</a></li> <li class="toctree-l3"><a class="reference internal" href="jep/18137-numpy-scipy-scope.html">18137: Scope of JAX NumPy & SciPy Wrappers</a></li> <li class="toctree-l3"><a class="reference internal" href="jep/25516-effver.html">25516: Effort-based versioning</a></li> </ul> </li> </ul> </li> <li class="toctree-l1 has-children"><a class="reference internal" href="extensions.html">Extension guides</a><input class="toctree-checkbox" id="toctree-checkbox-11" name="toctree-checkbox-11" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-11"><i class="fa-solid fa-chevron-down"></i></label><ul> <li class="toctree-l2"><a class="reference internal" href="notebooks/Writing_custom_interpreters_in_Jax.html">Writing custom Jaxpr interpreters in JAX</a></li> <li class="toctree-l2 has-children"><a class="reference internal" href="jax.extend.html"><code class="docutils literal notranslate"><span class="pre">jax.extend</span></code> module</a><input class="toctree-checkbox" id="toctree-checkbox-12" name="toctree-checkbox-12" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-12"><i class="fa-solid fa-chevron-down"></i></label><ul> <li class="toctree-l3"><a class="reference internal" href="jax.extend.core.html"><code class="docutils literal notranslate"><span class="pre">jax.extend.core</span></code> module</a></li> <li class="toctree-l3"><a class="reference internal" href="jax.extend.linear_util.html"><code class="docutils literal notranslate"><span class="pre">jax.extend.linear_util</span></code> module</a></li> <li class="toctree-l3"><a class="reference internal" href="jax.extend.mlir.html"><code class="docutils literal notranslate"><span class="pre">jax.extend.mlir</span></code> module</a></li> <li class="toctree-l3"><a class="reference internal" href="jax.extend.random.html"><code class="docutils literal notranslate"><span 
class="pre">jax.extend.random</span></code> module</a></li> </ul> </li> <li class="toctree-l2"><a class="reference internal" href="building_on_jax.html">Building on JAX</a></li> </ul> </li> <li class="toctree-l1 has-children"><a class="reference internal" href="notes.html">Notes</a><input class="toctree-checkbox" id="toctree-checkbox-13" name="toctree-checkbox-13" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-13"><i class="fa-solid fa-chevron-down"></i></label><ul> <li class="toctree-l2"><a class="reference internal" href="api_compatibility.html">API compatibility</a></li> <li class="toctree-l2"><a class="reference internal" href="deprecation.html">Python and NumPy version support policy</a></li> <li class="toctree-l2"><a class="reference internal" href="jax_array_migration.html">jax.Array migration</a></li> <li class="toctree-l2"><a class="reference internal" href="async_dispatch.html">Asynchronous dispatch</a></li> <li class="toctree-l2"><a class="reference internal" href="concurrency.html">Concurrency</a></li> <li class="toctree-l2"><a class="reference internal" href="gpu_memory_allocation.html">GPU memory allocation</a></li> <li class="toctree-l2"><a class="reference internal" href="rank_promotion_warning.html">Rank promotion warning</a></li> </ul> </li> <li class="toctree-l1 current active has-children"><a class="reference internal" href="jax.html">Public API: <code class="docutils literal notranslate"><span class="pre">jax</span></code> package</a><input checked="" class="toctree-checkbox" id="toctree-checkbox-14" name="toctree-checkbox-14" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-14"><i class="fa-solid fa-chevron-down"></i></label><ul class="current"> <li class="toctree-l2 has-children"><a class="reference internal" href="jax.numpy.html"><code class="docutils literal notranslate"><span class="pre">jax.numpy</span></code> module</a><input class="toctree-checkbox" id="toctree-checkbox-15" 
name="toctree-checkbox-15" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-15"><i class="fa-solid fa-chevron-down"></i></label><ul> <li class="toctree-l3"><a class="reference internal" href="_autosummary/jax.numpy.fft.fft.html">jax.numpy.fft.fft</a></li> <li class="toctree-l3"><a class="reference internal" href="_autosummary/jax.numpy.fft.fft2.html">jax.numpy.fft.fft2</a></li> <li class="toctree-l3"><a class="reference internal" href="_autosummary/jax.numpy.fft.fftfreq.html">jax.numpy.fft.fftfreq</a></li> <li class="toctree-l3"><a class="reference internal" href="_autosummary/jax.numpy.fft.fftn.html">jax.numpy.fft.fftn</a></li> <li class="toctree-l3"><a class="reference internal" href="_autosummary/jax.numpy.fft.fftshift.html">jax.numpy.fft.fftshift</a></li> <li class="toctree-l3"><a class="reference internal" href="_autosummary/jax.numpy.fft.hfft.html">jax.numpy.fft.hfft</a></li> <li class="toctree-l3"><a class="reference internal" href="_autosummary/jax.numpy.fft.ifft.html">jax.numpy.fft.ifft</a></li> <li class="toctree-l3"><a class="reference internal" href="_autosummary/jax.numpy.fft.ifft2.html">jax.numpy.fft.ifft2</a></li> <li class="toctree-l3"><a class="reference internal" href="_autosummary/jax.numpy.fft.ifftn.html">jax.numpy.fft.ifftn</a></li> <li class="toctree-l3"><a class="reference internal" href="_autosummary/jax.numpy.fft.ifftshift.html">jax.numpy.fft.ifftshift</a></li> <li class="toctree-l3"><a class="reference internal" href="_autosummary/jax.numpy.fft.ihfft.html">jax.numpy.fft.ihfft</a></li> <li class="toctree-l3"><a class="reference internal" href="_autosummary/jax.numpy.fft.irfft.html">jax.numpy.fft.irfft</a></li> <li class="toctree-l3"><a class="reference internal" href="_autosummary/jax.numpy.fft.irfft2.html">jax.numpy.fft.irfft2</a></li> <li class="toctree-l3"><a class="reference internal" href="_autosummary/jax.numpy.fft.irfftn.html">jax.numpy.fft.irfftn</a></li> <li class="toctree-l3"><a class="reference internal" 
href="_autosummary/jax.numpy.fft.rfft.html">jax.numpy.fft.rfft</a></li> <li class="toctree-l3"><a class="reference internal" href="_autosummary/jax.numpy.fft.rfft2.html">jax.numpy.fft.rfft2</a></li> <li class="toctree-l3"><a class="reference internal" href="_autosummary/jax.numpy.fft.rfftfreq.html">jax.numpy.fft.rfftfreq</a></li> <li class="toctree-l3"><a class="reference internal" href="_autosummary/jax.numpy.fft.rfftn.html">jax.numpy.fft.rfftn</a></li> </ul> </li> <li class="toctree-l2 has-children"><a class="reference internal" href="jax.scipy.html"><code class="docutils literal notranslate"><span class="pre">jax.scipy</span></code> module</a><input class="toctree-checkbox" id="toctree-checkbox-16" name="toctree-checkbox-16" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-16"><i class="fa-solid fa-chevron-down"></i></label><ul> <li class="toctree-l3"><a class="reference internal" href="_autosummary/jax.scipy.stats.bernoulli.logpmf.html">jax.scipy.stats.bernoulli.logpmf</a></li> <li class="toctree-l3"><a class="reference internal" href="_autosummary/jax.scipy.stats.bernoulli.pmf.html">jax.scipy.stats.bernoulli.pmf</a></li> <li class="toctree-l3"><a class="reference internal" href="_autosummary/jax.scipy.stats.bernoulli.cdf.html">jax.scipy.stats.bernoulli.cdf</a></li> <li class="toctree-l3"><a class="reference internal" href="_autosummary/jax.scipy.stats.bernoulli.ppf.html">jax.scipy.stats.bernoulli.ppf</a></li> </ul> </li> <li class="toctree-l2"><a class="reference internal" href="jax.lax.html"><code class="docutils literal notranslate"><span class="pre">jax.lax</span></code> module</a></li> <li class="toctree-l2"><a class="reference internal" href="jax.random.html"><code class="docutils literal notranslate"><span class="pre">jax.random</span></code> module</a></li> <li class="toctree-l2 current active"><a class="current reference internal" href="#"><code class="docutils literal notranslate"><span class="pre">jax.sharding</span></code> 
module</a></li> <li class="toctree-l2"><a class="reference internal" href="jax.debug.html"><code class="docutils literal notranslate"><span class="pre">jax.debug</span></code> module</a></li> <li class="toctree-l2"><a class="reference internal" href="jax.dlpack.html"><code class="docutils literal notranslate"><span class="pre">jax.dlpack</span></code> module</a></li> <li class="toctree-l2"><a class="reference internal" href="jax.distributed.html"><code class="docutils literal notranslate"><span class="pre">jax.distributed</span></code> module</a></li> <li class="toctree-l2"><a class="reference internal" href="jax.dtypes.html"><code class="docutils literal notranslate"><span class="pre">jax.dtypes</span></code> module</a></li> <li class="toctree-l2"><a class="reference internal" href="jax.ffi.html"><code class="docutils literal notranslate"><span class="pre">jax.ffi</span></code> module</a></li> <li class="toctree-l2"><a class="reference internal" href="jax.flatten_util.html"><code class="docutils literal notranslate"><span class="pre">jax.flatten_util</span></code> module</a></li> <li class="toctree-l2"><a class="reference internal" href="jax.image.html"><code class="docutils literal notranslate"><span class="pre">jax.image</span></code> module</a></li> <li class="toctree-l2 has-children"><a class="reference internal" href="jax.nn.html"><code class="docutils literal notranslate"><span class="pre">jax.nn</span></code> module</a><input class="toctree-checkbox" id="toctree-checkbox-17" name="toctree-checkbox-17" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-17"><i class="fa-solid fa-chevron-down"></i></label><ul> <li class="toctree-l3"><a class="reference internal" href="jax.nn.initializers.html"><code class="docutils literal notranslate"><span class="pre">jax.nn.initializers</span></code> module</a></li> </ul> </li> <li class="toctree-l2"><a class="reference internal" href="jax.ops.html"><code class="docutils literal notranslate"><span 
class="pre">jax.ops</span></code> module</a></li> <li class="toctree-l2"><a class="reference internal" href="jax.profiler.html"><code class="docutils literal notranslate"><span class="pre">jax.profiler</span></code> module</a></li> <li class="toctree-l2"><a class="reference internal" href="jax.stages.html"><code class="docutils literal notranslate"><span class="pre">jax.stages</span></code> module</a></li> <li class="toctree-l2"><a class="reference internal" href="jax.test_util.html"><code class="docutils literal notranslate"><span class="pre">jax.test_util</span></code> module</a></li> <li class="toctree-l2"><a class="reference internal" href="jax.tree.html"><code class="docutils literal notranslate"><span class="pre">jax.tree</span></code> module</a></li> <li class="toctree-l2"><a class="reference internal" href="jax.tree_util.html"><code class="docutils literal notranslate"><span class="pre">jax.tree_util</span></code> module</a></li> <li class="toctree-l2"><a class="reference internal" href="jax.typing.html"><code class="docutils literal notranslate"><span class="pre">jax.typing</span></code> module</a></li> <li class="toctree-l2"><a class="reference internal" href="jax.export.html"><code class="docutils literal notranslate"><span class="pre">jax.export</span></code> module</a></li> <li class="toctree-l2 has-children"><a class="reference internal" href="jax.extend.html"><code class="docutils literal notranslate"><span class="pre">jax.extend</span></code> module</a><input class="toctree-checkbox" id="toctree-checkbox-18" name="toctree-checkbox-18" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-18"><i class="fa-solid fa-chevron-down"></i></label><ul> <li class="toctree-l3"><a class="reference internal" href="jax.extend.core.html"><code class="docutils literal notranslate"><span class="pre">jax.extend.core</span></code> module</a></li> <li class="toctree-l3"><a class="reference internal" href="jax.extend.linear_util.html"><code 
class="docutils literal notranslate"><span class="pre">jax.extend.linear_util</span></code> module</a></li> <li class="toctree-l3"><a class="reference internal" href="jax.extend.mlir.html"><code class="docutils literal notranslate"><span class="pre">jax.extend.mlir</span></code> module</a></li> <li class="toctree-l3"><a class="reference internal" href="jax.extend.random.html"><code class="docutils literal notranslate"><span class="pre">jax.extend.random</span></code> module</a></li> </ul> </li> <li class="toctree-l2 has-children"><a class="reference internal" href="jax.example_libraries.html"><code class="docutils literal notranslate"><span class="pre">jax.example_libraries</span></code> module</a><input class="toctree-checkbox" id="toctree-checkbox-19" name="toctree-checkbox-19" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-19"><i class="fa-solid fa-chevron-down"></i></label><ul> <li class="toctree-l3"><a class="reference internal" href="jax.example_libraries.optimizers.html"><code class="docutils literal notranslate"><span class="pre">jax.example_libraries.optimizers</span></code> module</a></li> <li class="toctree-l3"><a class="reference internal" href="jax.example_libraries.stax.html"><code class="docutils literal notranslate"><span class="pre">jax.example_libraries.stax</span></code> module</a></li> </ul> </li> <li class="toctree-l2 has-children"><a class="reference internal" href="jax.experimental.html"><code class="docutils literal notranslate"><span class="pre">jax.experimental</span></code> module</a><input class="toctree-checkbox" id="toctree-checkbox-20" name="toctree-checkbox-20" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-20"><i class="fa-solid fa-chevron-down"></i></label><ul> <li class="toctree-l3"><a class="reference internal" href="jax.experimental.checkify.html"><code class="docutils literal notranslate"><span class="pre">jax.experimental.checkify</span></code> module</a></li> <li 
class="toctree-l3"><a class="reference internal" href="jax.experimental.compilation_cache.html"><code class="docutils literal notranslate"><span class="pre">jax.experimental.compilation_cache</span></code> module</a></li> <li class="toctree-l3"><a class="reference internal" href="jax.experimental.custom_dce.html"><code class="docutils literal notranslate"><span class="pre">jax.experimental.custom_dce</span></code> module</a></li> <li class="toctree-l3"><a class="reference internal" href="jax.experimental.custom_partitioning.html"><code class="docutils literal notranslate"><span class="pre">jax.experimental.custom_partitioning</span></code> module</a></li> <li class="toctree-l3"><a class="reference internal" href="jax.experimental.jet.html"><code class="docutils literal notranslate"><span class="pre">jax.experimental.jet</span></code> module</a></li> <li class="toctree-l3"><a class="reference internal" href="jax.experimental.key_reuse.html"><code class="docutils literal notranslate"><span class="pre">jax.experimental.key_reuse</span></code> module</a></li> <li class="toctree-l3"><a class="reference internal" href="jax.experimental.mesh_utils.html"><code class="docutils literal notranslate"><span class="pre">jax.experimental.mesh_utils</span></code> module</a></li> <li class="toctree-l3"><a class="reference internal" href="jax.experimental.multihost_utils.html"><code class="docutils literal notranslate"><span class="pre">jax.experimental.multihost_utils</span></code> module</a></li> <li class="toctree-l3 has-children"><a class="reference internal" href="jax.experimental.pallas.html"><code class="docutils literal notranslate"><span class="pre">jax.experimental.pallas</span></code> module</a><input class="toctree-checkbox" id="toctree-checkbox-21" name="toctree-checkbox-21" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-21"><i class="fa-solid fa-chevron-down"></i></label><ul> <li class="toctree-l4"><a class="reference internal" 
href="jax.experimental.pallas.mosaic_gpu.html"><code class="docutils literal notranslate"><span class="pre">jax.experimental.pallas.mosaic_gpu</span></code> module</a></li> <li class="toctree-l4"><a class="reference internal" href="jax.experimental.pallas.triton.html"><code class="docutils literal notranslate"><span class="pre">jax.experimental.pallas.triton</span></code> module</a></li> <li class="toctree-l4"><a class="reference internal" href="jax.experimental.pallas.tpu.html"><code class="docutils literal notranslate"><span class="pre">jax.experimental.pallas.tpu</span></code> module</a></li> </ul> </li> <li class="toctree-l3"><a class="reference internal" href="jax.experimental.pjit.html"><code class="docutils literal notranslate"><span class="pre">jax.experimental.pjit</span></code> module</a></li> <li class="toctree-l3"><a class="reference internal" href="jax.experimental.serialize_executable.html"><code class="docutils literal notranslate"><span class="pre">jax.experimental.serialize_executable</span></code> module</a></li> <li class="toctree-l3"><a class="reference internal" href="jax.experimental.shard_map.html"><code class="docutils literal notranslate"><span class="pre">jax.experimental.shard_map</span></code> module</a></li> <li class="toctree-l3 has-children"><a class="reference internal" href="jax.experimental.sparse.html"><code class="docutils literal notranslate"><span class="pre">jax.experimental.sparse</span></code> module</a><input class="toctree-checkbox" id="toctree-checkbox-22" name="toctree-checkbox-22" type="checkbox"/><label class="toctree-toggle" for="toctree-checkbox-22"><i class="fa-solid fa-chevron-down"></i></label><ul> <li class="toctree-l4"><a class="reference internal" href="_autosummary/jax.experimental.sparse.BCOO.html">jax.experimental.sparse.BCOO</a></li> <li class="toctree-l4"><a class="reference internal" 
href="_autosummary/jax.experimental.sparse.bcoo_broadcast_in_dim.html">jax.experimental.sparse.bcoo_broadcast_in_dim</a></li> <li class="toctree-l4"><a class="reference internal" href="_autosummary/jax.experimental.sparse.bcoo_concatenate.html">jax.experimental.sparse.bcoo_concatenate</a></li> <li class="toctree-l4"><a class="reference internal" href="_autosummary/jax.experimental.sparse.bcoo_dot_general.html">jax.experimental.sparse.bcoo_dot_general</a></li> <li class="toctree-l4"><a class="reference internal" href="_autosummary/jax.experimental.sparse.bcoo_dot_general_sampled.html">jax.experimental.sparse.bcoo_dot_general_sampled</a></li> <li class="toctree-l4"><a class="reference internal" href="_autosummary/jax.experimental.sparse.bcoo_dynamic_slice.html">jax.experimental.sparse.bcoo_dynamic_slice</a></li> <li class="toctree-l4"><a class="reference internal" href="_autosummary/jax.experimental.sparse.bcoo_extract.html">jax.experimental.sparse.bcoo_extract</a></li> <li class="toctree-l4"><a class="reference internal" href="_autosummary/jax.experimental.sparse.bcoo_fromdense.html">jax.experimental.sparse.bcoo_fromdense</a></li> <li class="toctree-l4"><a class="reference internal" href="_autosummary/jax.experimental.sparse.bcoo_gather.html">jax.experimental.sparse.bcoo_gather</a></li> <li class="toctree-l4"><a class="reference internal" href="_autosummary/jax.experimental.sparse.bcoo_multiply_dense.html">jax.experimental.sparse.bcoo_multiply_dense</a></li> <li class="toctree-l4"><a class="reference internal" href="_autosummary/jax.experimental.sparse.bcoo_multiply_sparse.html">jax.experimental.sparse.bcoo_multiply_sparse</a></li> <li class="toctree-l4"><a class="reference internal" href="_autosummary/jax.experimental.sparse.bcoo_update_layout.html">jax.experimental.sparse.bcoo_update_layout</a></li> <li class="toctree-l4"><a class="reference internal" href="_autosummary/jax.experimental.sparse.bcoo_reduce_sum.html">jax.experimental.sparse.bcoo_reduce_sum</a></li> 
<li class="toctree-l4"><a class="reference internal" href="_autosummary/jax.experimental.sparse.bcoo_reshape.html">jax.experimental.sparse.bcoo_reshape</a></li> <li class="toctree-l4"><a class="reference internal" href="_autosummary/jax.experimental.sparse.bcoo_slice.html">jax.experimental.sparse.bcoo_slice</a></li> <li class="toctree-l4"><a class="reference internal" href="_autosummary/jax.experimental.sparse.bcoo_sort_indices.html">jax.experimental.sparse.bcoo_sort_indices</a></li> <li class="toctree-l4"><a class="reference internal" href="_autosummary/jax.experimental.sparse.bcoo_squeeze.html">jax.experimental.sparse.bcoo_squeeze</a></li> <li class="toctree-l4"><a class="reference internal" href="_autosummary/jax.experimental.sparse.bcoo_sum_duplicates.html">jax.experimental.sparse.bcoo_sum_duplicates</a></li> <li class="toctree-l4"><a class="reference internal" href="_autosummary/jax.experimental.sparse.bcoo_todense.html">jax.experimental.sparse.bcoo_todense</a></li> <li class="toctree-l4"><a class="reference internal" href="_autosummary/jax.experimental.sparse.bcoo_transpose.html">jax.experimental.sparse.bcoo_transpose</a></li> </ul> </li> </ul> </li> <li class="toctree-l2"><a class="reference internal" href="jax.lib.html"><code class="docutils literal notranslate"><span class="pre">jax.lib</span></code> module</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.addressable_shards.html">jax.Array.addressable_shards</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.all.html">jax.Array.all</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.any.html">jax.Array.any</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.argmax.html">jax.Array.argmax</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.argmin.html">jax.Array.argmin</a></li> <li class="toctree-l2"><a 
class="reference internal" href="_autosummary/jax.Array.argpartition.html">jax.Array.argpartition</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.argsort.html">jax.Array.argsort</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.astype.html">jax.Array.astype</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.at.html">jax.Array.at</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.choose.html">jax.Array.choose</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.clip.html">jax.Array.clip</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.compress.html">jax.Array.compress</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.committed.html">jax.Array.committed</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.conj.html">jax.Array.conj</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.conjugate.html">jax.Array.conjugate</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.copy.html">jax.Array.copy</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.copy_to_host_async.html">jax.Array.copy_to_host_async</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.cumprod.html">jax.Array.cumprod</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.cumsum.html">jax.Array.cumsum</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.device.html">jax.Array.device</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.diagonal.html">jax.Array.diagonal</a></li> <li class="toctree-l2"><a class="reference 
internal" href="_autosummary/jax.Array.dot.html">jax.Array.dot</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.dtype.html">jax.Array.dtype</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.flat.html">jax.Array.flat</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.flatten.html">jax.Array.flatten</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.global_shards.html">jax.Array.global_shards</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.imag.html">jax.Array.imag</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.is_fully_addressable.html">jax.Array.is_fully_addressable</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.is_fully_replicated.html">jax.Array.is_fully_replicated</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.item.html">jax.Array.item</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.itemsize.html">jax.Array.itemsize</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.max.html">jax.Array.max</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.mean.html">jax.Array.mean</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.min.html">jax.Array.min</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.nbytes.html">jax.Array.nbytes</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.ndim.html">jax.Array.ndim</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.nonzero.html">jax.Array.nonzero</a></li> <li class="toctree-l2"><a class="reference internal" 
href="_autosummary/jax.Array.prod.html">jax.Array.prod</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.ptp.html">jax.Array.ptp</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.ravel.html">jax.Array.ravel</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.real.html">jax.Array.real</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.repeat.html">jax.Array.repeat</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.reshape.html">jax.Array.reshape</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.round.html">jax.Array.round</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.searchsorted.html">jax.Array.searchsorted</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.shape.html">jax.Array.shape</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.sharding.html">jax.Array.sharding</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.size.html">jax.Array.size</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.sort.html">jax.Array.sort</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.squeeze.html">jax.Array.squeeze</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.std.html">jax.Array.std</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.sum.html">jax.Array.sum</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.swapaxes.html">jax.Array.swapaxes</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.take.html">jax.Array.take</a></li> <li 
class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.to_device.html">jax.Array.to_device</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.trace.html">jax.Array.trace</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.transpose.html">jax.Array.transpose</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.var.html">jax.Array.var</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.view.html">jax.Array.view</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.T.html">jax.Array.T</a></li> <li class="toctree-l2"><a class="reference internal" href="_autosummary/jax.Array.mT.html">jax.Array.mT</a></li> </ul> </li> <li class="toctree-l1"><a class="reference internal" href="about.html">About the project</a></li> </ul> <ul class="nav bd-sidenav"> <li class="toctree-l1"><a class="reference internal" href="changelog.html">Change log</a></li> <li class="toctree-l1"><a class="reference internal" href="glossary.html">Glossary of terms</a></li> </ul> </div> </nav></div> </div> <div class="sidebar-primary-items__end sidebar-primary__section"> </div> <div id="rtd-footer-container"></div> </div> <main id="main-content" class="bd-main"> <div class="sbt-scroll-pixel-helper"></div> <div class="bd-content"> <div class="bd-article-container"> <div class="bd-header-article"> <div class="header-article-items header-article__inner"> <div class="header-article-items__start"> <div class="header-article-item"><label class="sidebar-toggle primary-toggle btn btn-sm" for="__primary" title="Toggle primary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip"> <span class="fa-solid fa-bars"></span> </label></div> </div> <div class="header-article-items__end"> <div class="header-article-item"> <div class="article-header-buttons"> <a href="https://github.com/jax-ml/jax" 
target="_blank" class="btn btn-sm btn-source-repository-button" title="Source repository" data-bs-placement="bottom" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fab fa-github"></i> </span> </a> <div class="dropdown dropdown-download-buttons"> <button class="btn dropdown-toggle" type="button" data-bs-toggle="dropdown" aria-expanded="false" aria-label="Download this page"> <i class="fas fa-download"></i> </button> <ul class="dropdown-menu"> <li><a href="_sources/jax.sharding.rst" target="_blank" class="btn btn-sm btn-download-source-button dropdown-item" title="Download source file" data-bs-placement="left" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-file"></i> </span> <span class="btn__text-container">.rst</span> </a> </li> <li> <button onclick="window.print()" class="btn btn-sm btn-download-pdf-button dropdown-item" title="Print to PDF" data-bs-placement="left" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-file-pdf"></i> </span> <span class="btn__text-container">.pdf</span> </button> </li> </ul> </div> <button onclick="toggleFullScreen()" class="btn btn-sm btn-fullscreen-button" title="Fullscreen mode" data-bs-placement="bottom" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-expand"></i> </span> </button> <script> document.write(` <button class="btn btn-sm navbar-btn theme-switch-button" title="light/dark" aria-label="light/dark" data-bs-placement="bottom" data-bs-toggle="tooltip"> <span class="theme-switch nav-link" data-mode="light"><i class="fa-solid fa-sun fa-lg"></i></span> <span class="theme-switch nav-link" data-mode="dark"><i class="fa-solid fa-moon fa-lg"></i></span> <span class="theme-switch nav-link" data-mode="auto"><i class="fa-solid fa-circle-half-stroke fa-lg"></i></span> </button> `); </script> <script> document.write(` <button class="btn btn-sm navbar-btn search-button search-button__button" title="Search" 
aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="fa-solid fa-magnifying-glass fa-lg"></i> </button> `); </script> <label class="sidebar-toggle secondary-toggle btn btn-sm" for="__secondary" title="Toggle secondary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip"> <span class="fa-solid fa-list"></span> </label> </div></div> </div> </div> </div> <div id="jb-print-docs-body" class="onlyprint"> <h1>jax.sharding module</h1> <!-- Table of contents --> <div id="print-main-content"> <div id="jb-print-toc"> <div> <h2> Contents </h2> </div> <nav aria-label="Page"> <ul class="visible nav section-nav flex-column"> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#classes">Classes</a><ul class="visible nav section-nav flex-column"> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding"><code class="docutils literal notranslate"><span class="pre">Sharding</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding.addressable_devices"><code class="docutils literal notranslate"><span class="pre">Sharding.addressable_devices</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding.addressable_devices_indices_map"><code class="docutils literal notranslate"><span class="pre">Sharding.addressable_devices_indices_map()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding.device_set"><code class="docutils literal notranslate"><span class="pre">Sharding.device_set</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding.devices_indices_map"><code class="docutils literal notranslate"><span 
class="pre">Sharding.devices_indices_map()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding.is_equivalent_to"><code class="docutils literal notranslate"><span class="pre">Sharding.is_equivalent_to()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding.is_fully_addressable"><code class="docutils literal notranslate"><span class="pre">Sharding.is_fully_addressable</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding.is_fully_replicated"><code class="docutils literal notranslate"><span class="pre">Sharding.is_fully_replicated</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding.memory_kind"><code class="docutils literal notranslate"><span class="pre">Sharding.memory_kind</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding.num_devices"><code class="docutils literal notranslate"><span class="pre">Sharding.num_devices</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding.shard_shape"><code class="docutils literal notranslate"><span class="pre">Sharding.shard_shape()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding.with_memory_kind"><code class="docutils literal notranslate"><span class="pre">Sharding.with_memory_kind()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.SingleDeviceSharding"><code class="docutils literal notranslate"><span class="pre">SingleDeviceSharding</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item 
toc-entry"><a class="reference internal nav-link" href="#jax.sharding.SingleDeviceSharding.device_set"><code class="docutils literal notranslate"><span class="pre">SingleDeviceSharding.device_set</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.SingleDeviceSharding.devices_indices_map"><code class="docutils literal notranslate"><span class="pre">SingleDeviceSharding.devices_indices_map()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.SingleDeviceSharding.is_fully_addressable"><code class="docutils literal notranslate"><span class="pre">SingleDeviceSharding.is_fully_addressable</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.SingleDeviceSharding.is_fully_replicated"><code class="docutils literal notranslate"><span class="pre">SingleDeviceSharding.is_fully_replicated</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.SingleDeviceSharding.memory_kind"><code class="docutils literal notranslate"><span class="pre">SingleDeviceSharding.memory_kind</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.SingleDeviceSharding.num_devices"><code class="docutils literal notranslate"><span class="pre">SingleDeviceSharding.num_devices</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.SingleDeviceSharding.with_memory_kind"><code class="docutils literal notranslate"><span class="pre">SingleDeviceSharding.with_memory_kind()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.NamedSharding"><code class="docutils literal notranslate"><span class="pre">NamedSharding</span></code></a><ul 
class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.NamedSharding.addressable_devices"><code class="docutils literal notranslate"><span class="pre">NamedSharding.addressable_devices</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.NamedSharding.device_set"><code class="docutils literal notranslate"><span class="pre">NamedSharding.device_set</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.NamedSharding.is_fully_addressable"><code class="docutils literal notranslate"><span class="pre">NamedSharding.is_fully_addressable</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.NamedSharding.is_fully_replicated"><code class="docutils literal notranslate"><span class="pre">NamedSharding.is_fully_replicated</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.NamedSharding.memory_kind"><code class="docutils literal notranslate"><span class="pre">NamedSharding.memory_kind</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.NamedSharding.mesh"><code class="docutils literal notranslate"><span class="pre">NamedSharding.mesh</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.NamedSharding.num_devices"><code class="docutils literal notranslate"><span class="pre">NamedSharding.num_devices</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.NamedSharding.spec"><code class="docutils literal notranslate"><span class="pre">NamedSharding.spec</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal 
nav-link" href="#jax.sharding.NamedSharding.with_memory_kind"><code class="docutils literal notranslate"><span class="pre">NamedSharding.with_memory_kind()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PositionalSharding"><code class="docutils literal notranslate"><span class="pre">PositionalSharding</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PositionalSharding.device_set"><code class="docutils literal notranslate"><span class="pre">PositionalSharding.device_set</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PositionalSharding.is_fully_addressable"><code class="docutils literal notranslate"><span class="pre">PositionalSharding.is_fully_addressable</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PositionalSharding.is_fully_replicated"><code class="docutils literal notranslate"><span class="pre">PositionalSharding.is_fully_replicated</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PositionalSharding.memory_kind"><code class="docutils literal notranslate"><span class="pre">PositionalSharding.memory_kind</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PositionalSharding.num_devices"><code class="docutils literal notranslate"><span class="pre">PositionalSharding.num_devices</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PositionalSharding.with_memory_kind"><code class="docutils literal notranslate"><span class="pre">PositionalSharding.with_memory_kind()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item 
toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding"><code class="docutils literal notranslate"><span class="pre">PmapSharding</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.default"><code class="docutils literal notranslate"><span class="pre">PmapSharding.default()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.device_set"><code class="docutils literal notranslate"><span class="pre">PmapSharding.device_set</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.devices"><code class="docutils literal notranslate"><span class="pre">PmapSharding.devices</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.devices_indices_map"><code class="docutils literal notranslate"><span class="pre">PmapSharding.devices_indices_map()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.is_equivalent_to"><code class="docutils literal notranslate"><span class="pre">PmapSharding.is_equivalent_to()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.is_fully_addressable"><code class="docutils literal notranslate"><span class="pre">PmapSharding.is_fully_addressable</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.is_fully_replicated"><code class="docutils literal notranslate"><span class="pre">PmapSharding.is_fully_replicated</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" 
href="#jax.sharding.PmapSharding.memory_kind"><code class="docutils literal notranslate"><span class="pre">PmapSharding.memory_kind</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.num_devices"><code class="docutils literal notranslate"><span class="pre">PmapSharding.num_devices</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.shard_shape"><code class="docutils literal notranslate"><span class="pre">PmapSharding.shard_shape()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.sharding_spec"><code class="docutils literal notranslate"><span class="pre">PmapSharding.sharding_spec</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.with_memory_kind"><code class="docutils literal notranslate"><span class="pre">PmapSharding.with_memory_kind()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.GSPMDSharding"><code class="docutils literal notranslate"><span class="pre">GSPMDSharding</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.GSPMDSharding.device_set"><code class="docutils literal notranslate"><span class="pre">GSPMDSharding.device_set</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.GSPMDSharding.is_fully_addressable"><code class="docutils literal notranslate"><span class="pre">GSPMDSharding.is_fully_addressable</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.GSPMDSharding.is_fully_replicated"><code class="docutils 
literal notranslate"><span class="pre">GSPMDSharding.is_fully_replicated</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.GSPMDSharding.memory_kind"><code class="docutils literal notranslate"><span class="pre">GSPMDSharding.memory_kind</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.GSPMDSharding.num_devices"><code class="docutils literal notranslate"><span class="pre">GSPMDSharding.num_devices</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.GSPMDSharding.with_memory_kind"><code class="docutils literal notranslate"><span class="pre">GSPMDSharding.with_memory_kind()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PartitionSpec"><code class="docutils literal notranslate"><span class="pre">PartitionSpec</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Mesh"><code class="docutils literal notranslate"><span class="pre">Mesh</span></code></a></li> </ul> </li> </ul> </nav> </div> </div> </div> <div id="searchbox"></div> <article class="bd-article" role="main"> <section id="module-jax.sharding"> <span id="jax-sharding-module"></span><h1><code class="docutils literal notranslate"><span class="pre">jax.sharding</span></code> module<a class="headerlink" href="#module-jax.sharding" title="Link to this heading">#</a></h1> <section id="classes"> <h2>Classes<a class="headerlink" href="#classes" title="Link to this heading">#</a></h2> <dl class="py class"> <dt class="sig sig-object py" id="jax.sharding.Sharding"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">jax.sharding.</span></span><span class="sig-name descname"><span 
class="pre">Sharding</span></span><a class="headerlink" href="#jax.sharding.Sharding" title="Link to this definition">#</a></dt> <dd><p>Describes how a <a class="reference internal" href="_autosummary/jax.Array.html#jax.Array" title="jax.Array"><code class="xref py py-class docutils literal notranslate"><span class="pre">jax.Array</span></code></a> is laid out across devices.</p> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.Sharding.addressable_devices"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">addressable_devices</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#set" title="(in Python v3.13)"><span class="pre">set</span></a><span class="p"><span class="pre">[</span></span><a class="reference internal" href="_autosummary/jax.Device.html#jax.Device" title="jax.Device"><span class="pre">Device</span></a><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#jax.sharding.Sharding.addressable_devices" title="Link to this definition">#</a></dt> <dd><p>The set of devices in the <a class="reference internal" href="#jax.sharding.Sharding" title="jax.sharding.Sharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sharding</span></code></a> that are addressable by the current process.</p> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="jax.sharding.Sharding.addressable_devices_indices_map"> <span class="sig-name descname"><span class="pre">addressable_devices_indices_map</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">global_shape</span></span></em><span class="sig-paren">)</span><a class="reference external" 
href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding.py#L156-L164"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.Sharding.addressable_devices_indices_map" title="Link to this definition">#</a></dt> <dd><p>A mapping from addressable devices to the slice of array data each contains.</p> <p><code class="docutils literal notranslate"><span class="pre">addressable_devices_indices_map</span></code> contains the part of <a class="reference internal" href="#jax.sharding.Sharding.devices_indices_map" title="jax.sharding.Sharding.devices_indices_map"><code class="xref py py-meth docutils literal notranslate"><span class="pre">devices_indices_map</span></code></a> that applies to the addressable devices.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters<span class="colon">:</span></dt> <dd class="field-odd"><p><strong>global_shape</strong> (<em>Shape</em>)</p> </dd> <dt class="field-even">Return type<span class="colon">:</span></dt> <dd class="field-even"><p>Mapping[<a class="reference internal" href="_autosummary/jax.Device.html#jax.Device" title="jax.Device">Device</a>, Index | None]</p> </dd> </dl> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.Sharding.device_set"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">device_set</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#set" title="(in Python v3.13)"><span class="pre">set</span></a><span class="p"><span class="pre">[</span></span><a class="reference internal" href="_autosummary/jax.Device.html#jax.Device" title="jax.Device"><span class="pre">Device</span></a><span class="p"><span class="pre">]</span></span></em><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding.py#L86-L94"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" 
href="#jax.sharding.Sharding.device_set" title="Link to this definition">#</a></dt> <dd><p>The set of devices that this <a class="reference internal" href="#jax.sharding.Sharding" title="jax.sharding.Sharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sharding</span></code></a> spans.</p> <p>In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.</p> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="jax.sharding.Sharding.devices_indices_map"> <span class="sig-name descname"><span class="pre">devices_indices_map</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">global_shape</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding.py#L165-L172"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.Sharding.devices_indices_map" title="Link to this definition">#</a></dt> <dd><p>Returns a mapping from devices to the array slices each contains.</p> <p>The mapping includes all global devices, i.e., including non-addressable devices from other processes.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters<span class="colon">:</span></dt> <dd class="field-odd"><p><strong>global_shape</strong> (<em>Shape</em>)</p> </dd> <dt class="field-even">Return type<span class="colon">:</span></dt> <dd class="field-even"><p>Mapping[<a class="reference internal" href="_autosummary/jax.Device.html#jax.Device" title="jax.Device">Device</a>, Index]</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="jax.sharding.Sharding.is_equivalent_to"> <span class="sig-name descname"><span class="pre">is_equivalent_to</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">other</span></span></em>, <em 
class="sig-param"><span class="n"><span class="pre">ndim</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding.py#L190-L210"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.Sharding.is_equivalent_to" title="Link to this definition">#</a></dt> <dd><p>Returns <code class="docutils literal notranslate"><span class="pre">True</span></code> if two shardings are equivalent.</p> <p>Two shardings are equivalent if they place the same logical array shards on the same devices.</p> <p>For example, a <a class="reference internal" href="#jax.sharding.NamedSharding" title="jax.sharding.NamedSharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">NamedSharding</span></code></a> may be equivalent to a <a class="reference internal" href="#jax.sharding.PositionalSharding" title="jax.sharding.PositionalSharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">PositionalSharding</span></code></a> if both place the same shards of the array on the same devices.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters<span class="colon">:</span></dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>self</strong> (<a class="reference internal" href="#jax.sharding.Sharding" title="jax.sharding.Sharding"><em>Sharding</em></a>)</p></li> <li><p><strong>other</strong> (<a class="reference internal" href="#jax.sharding.Sharding" title="jax.sharding.Sharding"><em>Sharding</em></a>)</p></li> <li><p><strong>ndim</strong> (<a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.13)"><em>int</em></a>)</p></li> </ul> </dd> <dt class="field-even">Return type<span class="colon">:</span></dt> <dd class="field-even"><p><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" 
title="(in Python v3.13)">bool</a></p> </dd> </dl> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.Sharding.is_fully_addressable"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">is_fully_addressable</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.13)"><span class="pre">bool</span></a></em><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding.py#L104-L113"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.Sharding.is_fully_addressable" title="Link to this definition">#</a></dt> <dd><p>Is this sharding fully addressable?</p> <p>A sharding is fully addressable if the current process can address all of the devices named in the <a class="reference internal" href="#jax.sharding.Sharding" title="jax.sharding.Sharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sharding</span></code></a>. 
<code class="docutils literal notranslate"><span class="pre">is_fully_addressable</span></code> is equivalent to “is_local” in multi-process JAX.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.Sharding.is_fully_replicated"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">is_fully_replicated</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.13)"><span class="pre">bool</span></a></em><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding.py#L95-L103"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.Sharding.is_fully_replicated" title="Link to this definition">#</a></dt> <dd><p>Is this sharding fully replicated?</p> <p>A sharding is fully replicated if each device has a complete copy of the entire data.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.Sharding.memory_kind"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">memory_kind</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><span class="pre">str</span></a><span class="w"> </span><span class="p"><span class="pre">|</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.13)"><span class="pre">None</span></a></em><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding.py#L119-L123"><span 
class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.Sharding.memory_kind" title="Link to this definition">#</a></dt> <dd><p>Returns the memory kind of the sharding.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.Sharding.num_devices"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">num_devices</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.13)"><span class="pre">int</span></a></em><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding.py#L114-L118"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.Sharding.num_devices" title="Link to this definition">#</a></dt> <dd><p>Number of devices that the sharding contains.</p> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="jax.sharding.Sharding.shard_shape"> <span class="sig-name descname"><span class="pre">shard_shape</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">global_shape</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding.py#L182-L189"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.Sharding.shard_shape" title="Link to this definition">#</a></dt> <dd><p>Returns the shape of the data on each device.</p> <p>The shard shape returned by this function is calculated from <code class="docutils literal notranslate"><span class="pre">global_shape</span></code> and the properties of the sharding.</p> <dl class="field-list simple"> <dt 
class="field-odd">Parameters<span class="colon">:</span></dt> <dd class="field-odd"><p><strong>global_shape</strong> (<em>Shape</em>)</p> </dd> <dt class="field-even">Return type<span class="colon">:</span></dt> <dd class="field-even"><p>Shape</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="jax.sharding.Sharding.with_memory_kind"> <span class="sig-name descname"><span class="pre">with_memory_kind</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">kind</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding.py#L124-L127"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.Sharding.with_memory_kind" title="Link to this definition">#</a></dt> <dd><p>Returns a new Sharding instance with the specified memory kind.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters<span class="colon">:</span></dt> <dd class="field-odd"><p><strong>kind</strong> (<a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><em>str</em></a>)</p> </dd> <dt class="field-even">Return type<span class="colon">:</span></dt> <dd class="field-even"><p><a class="reference internal" href="#jax.sharding.Sharding" title="jax.sharding.Sharding">Sharding</a></p> </dd> </dl> </dd></dl> </dd></dl> <dl class="py class"> <dt class="sig sig-object py" id="jax.sharding.SingleDeviceSharding"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">jax.sharding.</span></span><span class="sig-name descname"><span class="pre">SingleDeviceSharding</span></span><a class="headerlink" href="#jax.sharding.SingleDeviceSharding" title="Link to this definition">#</a></dt> <dd><p>Bases: <a class="reference internal" 
href="#jax.sharding.Sharding" title="jaxlib.xla_extension.Sharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sharding</span></code></a></p> <p>A <a class="reference internal" href="#jax.sharding.Sharding" title="jax.sharding.Sharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sharding</span></code></a> that places its data on a single device.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters<span class="colon">:</span></dt> <dd class="field-odd"><p><strong>device</strong> – A single <code class="xref py py-class docutils literal notranslate"><span class="pre">Device</span></code>.</p> </dd> </dl> <p class="rubric">Examples</p> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">single_device_sharding</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">sharding</span><span class="o">.</span><span class="n">SingleDeviceSharding</span><span class="p">(</span> <span class="gp">... 
</span> <span class="n">jax</span><span class="o">.</span><span class="n">devices</span><span class="p">()[</span><span class="mi">0</span><span class="p">])</span> </pre></div> </div> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.SingleDeviceSharding.device_set"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">device_set</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#set" title="(in Python v3.13)"><span class="pre">set</span></a><span class="p"><span class="pre">[</span></span><a class="reference internal" href="_autosummary/jax.Device.html#jax.Device" title="jax.Device"><span class="pre">Device</span></a><span class="p"><span class="pre">]</span></span></em><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding_impls.py#L165-L168"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.SingleDeviceSharding.device_set" title="Link to this definition">#</a></dt> <dd><p>The set of devices that this <a class="reference internal" href="#jax.sharding.Sharding" title="jax.sharding.Sharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sharding</span></code></a> spans.</p> <p>In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.</p> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="jax.sharding.SingleDeviceSharding.devices_indices_map"> <span class="sig-name descname"><span class="pre">devices_indices_map</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">global_shape</span></span></em><span class="sig-paren">)</span><a class="reference external" 
href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding_impls.py#L176-L178"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.SingleDeviceSharding.devices_indices_map" title="Link to this definition">#</a></dt> <dd><p>Returns a mapping from devices to the array slices each contains.</p> <p>The mapping includes all global devices, i.e., including non-addressable devices from other processes.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters<span class="colon">:</span></dt> <dd class="field-odd"><p><strong>global_shape</strong> (<em>Shape</em>)</p> </dd> <dt class="field-even">Return type<span class="colon">:</span></dt> <dd class="field-even"><p>Mapping[<a class="reference internal" href="_autosummary/jax.Device.html#jax.Device" title="jax.Device">Device</a>, Index]</p> </dd> </dl> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.SingleDeviceSharding.is_fully_addressable"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">is_fully_addressable</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.13)"><span class="pre">bool</span></a></em><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding_impls.py#L195-L198"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.SingleDeviceSharding.is_fully_addressable" title="Link to this definition">#</a></dt> <dd><p>Is this sharding fully addressable?</p> <p>A sharding is fully addressable if the current process can address all of the devices named in the <a class="reference internal" href="#jax.sharding.Sharding" title="jax.sharding.Sharding"><code class="xref py 
py-class docutils literal notranslate"><span class="pre">Sharding</span></code></a>. <code class="docutils literal notranslate"><span class="pre">is_fully_addressable</span></code> is equivalent to “is_local” in multi-process JAX.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.SingleDeviceSharding.is_fully_replicated"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">is_fully_replicated</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.13)"><span class="pre">bool</span></a></em><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding_impls.py#L191-L194"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.SingleDeviceSharding.is_fully_replicated" title="Link to this definition">#</a></dt> <dd><p>Is this sharding fully replicated?</p> <p>A sharding is fully replicated if each device has a complete copy of the entire data.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.SingleDeviceSharding.memory_kind"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">memory_kind</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><span class="pre">str</span></a><span class="w"> </span><span class="p"><span class="pre">|</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.13)"><span class="pre">None</span></a></em><a 
class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding_impls.py#L169-L172"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.SingleDeviceSharding.memory_kind" title="Link to this definition">#</a></dt> <dd><p>Returns the memory kind of the sharding.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.SingleDeviceSharding.num_devices"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">num_devices</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.13)"><span class="pre">int</span></a></em><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding_impls.py#L161-L164"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.SingleDeviceSharding.num_devices" title="Link to this definition">#</a></dt> <dd><p>Number of devices that the sharding contains.</p> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="jax.sharding.SingleDeviceSharding.with_memory_kind"> <span class="sig-name descname"><span class="pre">with_memory_kind</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">kind</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding_impls.py#L173-L175"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.SingleDeviceSharding.with_memory_kind" title="Link to this definition">#</a></dt> <dd><p>Returns a new Sharding instance with the specified memory kind.</p> <dl class="field-list simple"> 
<dt class="field-odd">Parameters<span class="colon">:</span></dt> <dd class="field-odd"><p><strong>kind</strong> (<a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><em>str</em></a>)</p> </dd> <dt class="field-even">Return type<span class="colon">:</span></dt> <dd class="field-even"><p><a class="reference internal" href="#jax.sharding.SingleDeviceSharding" title="jax.sharding.SingleDeviceSharding">SingleDeviceSharding</a></p> </dd> </dl> </dd></dl> </dd></dl> <dl class="py class"> <dt class="sig sig-object py" id="jax.sharding.NamedSharding"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">jax.sharding.</span></span><span class="sig-name descname"><span class="pre">NamedSharding</span></span><a class="headerlink" href="#jax.sharding.NamedSharding" title="Link to this definition">#</a></dt> <dd><p>Bases: <a class="reference internal" href="#jax.sharding.Sharding" title="jaxlib.xla_extension.Sharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sharding</span></code></a></p> <p>A <a class="reference internal" href="#jax.sharding.NamedSharding" title="jax.sharding.NamedSharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">NamedSharding</span></code></a> expresses sharding using named axes.</p> <p>A <a class="reference internal" href="#jax.sharding.NamedSharding" title="jax.sharding.NamedSharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">NamedSharding</span></code></a> is a pair of a <a class="reference internal" href="#jax.sharding.Mesh" title="jax.sharding.Mesh"><code class="xref py py-class docutils literal notranslate"><span class="pre">Mesh</span></code></a> of devices and a <a class="reference internal" href="#jax.sharding.PartitionSpec" title="jax.sharding.PartitionSpec"><code class="xref py py-class docutils
literal notranslate"><span class="pre">PartitionSpec</span></code></a> which describes how to shard an array across that mesh.</p> <p>A <a class="reference internal" href="#jax.sharding.Mesh" title="jax.sharding.Mesh"><code class="xref py py-class docutils literal notranslate"><span class="pre">Mesh</span></code></a> is a multidimensional NumPy array of JAX devices, where each axis of the mesh has a name, e.g. <code class="docutils literal notranslate"><span class="pre">'x'</span></code> or <code class="docutils literal notranslate"><span class="pre">'y'</span></code>.</p> <p>A <a class="reference internal" href="#jax.sharding.PartitionSpec" title="jax.sharding.PartitionSpec"><code class="xref py py-class docutils literal notranslate"><span class="pre">PartitionSpec</span></code></a> is a tuple whose elements can be <code class="docutils literal notranslate"><span class="pre">None</span></code>, a mesh axis, or a tuple of mesh axes. Each element describes how an input dimension is partitioned across zero or more mesh dimensions. 
For example, <code class="docutils literal notranslate"><span class="pre">PartitionSpec('x',</span> <span class="pre">'y')</span></code> says that the first dimension of the data is sharded across the <code class="docutils literal notranslate"><span class="pre">x</span></code> axis of the mesh, and the second dimension is sharded across the <code class="docutils literal notranslate"><span class="pre">y</span></code> axis of the mesh.</p> <p>The <a class="reference external" href="https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#namedsharding-gives-a-way-to-express-shardings-with-names">Distributed arrays and automatic parallelization</a> tutorial has more details and diagrams that explain how <a class="reference internal" href="#jax.sharding.Mesh" title="jax.sharding.Mesh"><code class="xref py py-class docutils literal notranslate"><span class="pre">Mesh</span></code></a> and <a class="reference internal" href="#jax.sharding.PartitionSpec" title="jax.sharding.PartitionSpec"><code class="xref py py-class docutils literal notranslate"><span class="pre">PartitionSpec</span></code></a> are used.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters<span class="colon">:</span></dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>mesh</strong> – A <a class="reference internal" href="#jax.sharding.Mesh" title="jax.sharding.Mesh"><code class="xref py py-class docutils literal notranslate"><span class="pre">jax.sharding.Mesh</span></code></a> object.</p></li> <li><p><strong>spec</strong> – A <a class="reference internal" href="#jax.sharding.PartitionSpec" title="jax.sharding.PartitionSpec"><code class="xref py py-class docutils literal notranslate"><span class="pre">jax.sharding.PartitionSpec</span></code></a> object.</p></li> </ul> </dd> </dl> <p 
class="rubric">Examples</p> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">from</span><span class="w"> </span><span class="nn">jax.sharding</span><span class="w"> </span><span class="kn">import</span> <span class="n">Mesh</span> <span class="gp">>>> </span><span class="kn">from</span><span class="w"> </span><span class="nn">jax.sharding</span><span class="w"> </span><span class="kn">import</span> <span class="n">PartitionSpec</span> <span class="k">as</span> <span class="n">P</span> <span class="gp">>>> </span><span class="n">mesh</span> <span class="o">=</span> <span class="n">Mesh</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">devices</span><span class="p">())</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span> <span class="p">(</span><span class="s1">'x'</span><span class="p">,</span> <span class="s1">'y'</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">spec</span> <span class="o">=</span> <span class="n">P</span><span class="p">(</span><span class="s1">'x'</span><span class="p">,</span> <span class="s1">'y'</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">named_sharding</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">sharding</span><span class="o">.</span><span class="n">NamedSharding</span><span class="p">(</span><span class="n">mesh</span><span class="p">,</span> <span class="n">spec</span><span class="p">)</span> </pre></div> </div> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.NamedSharding.addressable_devices"> <em class="property"><span 
class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">addressable_devices</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#set" title="(in Python v3.13)"><span class="pre">set</span></a><span class="p"><span class="pre">[</span></span><a class="reference internal" href="_autosummary/jax.Device.html#jax.Device" title="jax.Device"><span class="pre">Device</span></a><span class="p"><span class="pre">]</span></span></em><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/named_sharding.py#L222-L230"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.NamedSharding.addressable_devices" title="Link to this definition">#</a></dt> <dd><p>The set of devices in the <a class="reference internal" href="#jax.sharding.Sharding" title="jax.sharding.Sharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sharding</span></code></a> that are addressable by the current process.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.NamedSharding.device_set"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">device_set</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#set" title="(in Python v3.13)"><span class="pre">set</span></a><span class="p"><span class="pre">[</span></span><a class="reference internal" href="_autosummary/jax.Device.html#jax.Device" title="jax.Device"><span class="pre">Device</span></a><span class="p"><span class="pre">]</span></span></em><a class="reference external" 
href="https://github.com/jax-ml/jax/blob/main/jax/_src/named_sharding.py#L193-L199"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.NamedSharding.device_set" title="Link to this definition">#</a></dt> <dd><p>The set of devices that this <a class="reference internal" href="#jax.sharding.Sharding" title="jax.sharding.Sharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sharding</span></code></a> spans.</p> <p>In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.NamedSharding.is_fully_addressable"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">is_fully_addressable</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.13)"><span class="pre">bool</span></a></em><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/named_sharding.py#L207-L215"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.NamedSharding.is_fully_addressable" title="Link to this definition">#</a></dt> <dd><p>Is this sharding fully addressable?</p> <p>A sharding is fully addressable if the current process can address all of the devices named in the <a class="reference internal" href="#jax.sharding.Sharding" title="jax.sharding.Sharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sharding</span></code></a>. 
<code class="docutils literal notranslate"><span class="pre">is_fully_addressable</span></code> is equivalent to “is_local” in multi-process JAX.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.NamedSharding.is_fully_replicated"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">is_fully_replicated</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.13)"><span class="pre">bool</span></a></em><a class="headerlink" href="#jax.sharding.NamedSharding.is_fully_replicated" title="Link to this definition">#</a></dt> <dd><p>Is this sharding fully replicated?</p> <p>A sharding is fully replicated if each device has a complete copy of the entire data.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.NamedSharding.memory_kind"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">memory_kind</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><span class="pre">str</span></a><span class="w"> </span><span class="p"><span class="pre">|</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.13)"><span class="pre">None</span></a></em><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/named_sharding.py#L146-L149"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.NamedSharding.memory_kind" title="Link to this 
definition">#</a></dt> <dd><p>Returns the memory kind of the sharding.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.NamedSharding.mesh"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">mesh</span></span><a class="headerlink" href="#jax.sharding.NamedSharding.mesh" title="Link to this definition">#</a></dt> <dd><p>The <code class="docutils literal notranslate"><span class="pre">Mesh</span></code> of devices over which this sharding is defined.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.NamedSharding.num_devices"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">num_devices</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.13)"><span class="pre">int</span></a></em><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/named_sharding.py#L189-L192"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.NamedSharding.num_devices" title="Link to this definition">#</a></dt> <dd><p>Number of devices that the sharding contains.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.NamedSharding.spec"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">spec</span></span><a class="headerlink" href="#jax.sharding.NamedSharding.spec" title="Link to this definition">#</a></dt> <dd><p>The <code class="docutils literal notranslate"><span class="pre">PartitionSpec</span></code> describing how each array dimension is partitioned over the named axes of the mesh.</p> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="jax.sharding.NamedSharding.with_memory_kind"> <span class="sig-name descname"><span class="pre">with_memory_kind</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span
class="pre">kind</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/named_sharding.py#L242-L244"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.NamedSharding.with_memory_kind" title="Link to this definition">#</a></dt> <dd><p>Returns a new Sharding instance with the specified memory kind.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters<span class="colon">:</span></dt> <dd class="field-odd"><p><strong>kind</strong> (<a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><em>str</em></a>)</p> </dd> <dt class="field-even">Return type<span class="colon">:</span></dt> <dd class="field-even"><p><a class="reference internal" href="#jax.sharding.NamedSharding" title="jax.sharding.NamedSharding">NamedSharding</a></p> </dd> </dl> </dd></dl> </dd></dl> <dl class="py class"> <dt class="sig sig-object py" id="jax.sharding.PositionalSharding"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">jax.sharding.</span></span><span class="sig-name descname"><span class="pre">PositionalSharding</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">devices</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">memory_kind</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding_impls.py#L419-L545"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.PositionalSharding" 
title="Link to this definition">#</a></dt> <dd><p>Bases: <a class="reference internal" href="#jax.sharding.Sharding" title="jaxlib.xla_extension.Sharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sharding</span></code></a></p> <dl class="field-list simple"> <dt class="field-odd">Parameters<span class="colon">:</span></dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>devices</strong> (<em>Sequence</em><em>[</em><em>xc.Device</em><em>] </em><em>| </em><em>np.ndarray</em>)</p></li> <li><p><strong>memory_kind</strong> (<a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><em>str</em></a><em> | </em><em>None</em>)</p></li> </ul> </dd> </dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.PositionalSharding.device_set"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">device_set</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#set" title="(in Python v3.13)"><span class="pre">set</span></a><span class="p"><span class="pre">[</span></span><span class="pre">xc.Device</span><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#jax.sharding.PositionalSharding.device_set" title="Link to this definition">#</a></dt> <dd><p>The set of devices that this <a class="reference internal" href="#jax.sharding.Sharding" title="jax.sharding.Sharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sharding</span></code></a> spans.</p> <p>In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.PositionalSharding.is_fully_addressable"> 
<em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">is_fully_addressable</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.13)"><span class="pre">bool</span></a></em><a class="headerlink" href="#jax.sharding.PositionalSharding.is_fully_addressable" title="Link to this definition">#</a></dt> <dd><p>Is this sharding fully addressable?</p> <p>A sharding is fully addressable if the current process can address all of the devices named in the <a class="reference internal" href="#jax.sharding.Sharding" title="jax.sharding.Sharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sharding</span></code></a>. <code class="docutils literal notranslate"><span class="pre">is_fully_addressable</span></code> is equivalent to “is_local” in multi-process JAX.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.PositionalSharding.is_fully_replicated"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">is_fully_replicated</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.13)"><span class="pre">bool</span></a></em><a class="headerlink" href="#jax.sharding.PositionalSharding.is_fully_replicated" title="Link to this definition">#</a></dt> <dd><p>Is this sharding fully replicated?</p> <p>A sharding is fully replicated if each device has a complete copy of the entire data.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.PositionalSharding.memory_kind"> <em class="property"><span 
class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">memory_kind</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><span class="pre">str</span></a><span class="w"> </span><span class="p"><span class="pre">|</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.13)"><span class="pre">None</span></a></em><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding_impls.py#L518-L521"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.PositionalSharding.memory_kind" title="Link to this definition">#</a></dt> <dd><p>Returns the memory kind of the sharding.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.PositionalSharding.num_devices"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">num_devices</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.13)"><span class="pre">int</span></a></em><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding_impls.py#L510-L513"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.PositionalSharding.num_devices" title="Link to this definition">#</a></dt> <dd><p>Number of devices that the sharding contains.</p> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="jax.sharding.PositionalSharding.with_memory_kind"> <span 
class="sig-name descname"><span class="pre">with_memory_kind</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">kind</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding_impls.py#L522-L524"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.PositionalSharding.with_memory_kind" title="Link to this definition">#</a></dt> <dd><p>Returns a new Sharding instance with the specified memory kind.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters<span class="colon">:</span></dt> <dd class="field-odd"><p><strong>kind</strong> (<a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><em>str</em></a>)</p> </dd> <dt class="field-even">Return type<span class="colon">:</span></dt> <dd class="field-even"><p><a class="reference internal" href="#jax.sharding.PositionalSharding" title="jax.sharding.PositionalSharding">PositionalSharding</a></p> </dd> </dl> </dd></dl> </dd></dl> <dl class="py class"> <dt class="sig sig-object py" id="jax.sharding.PmapSharding"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">jax.sharding.</span></span><span class="sig-name descname"><span class="pre">PmapSharding</span></span><a class="headerlink" href="#jax.sharding.PmapSharding" title="Link to this definition">#</a></dt> <dd><p>Bases: <a class="reference internal" href="#jax.sharding.Sharding" title="jaxlib.xla_extension.Sharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sharding</span></code></a></p> <p>Describes a sharding used by <a class="reference internal" href="_autosummary/jax.pmap.html#jax.pmap" title="jax.pmap"><code class="xref py py-func docutils literal notranslate"><span 
class="pre">jax.pmap()</span></code></a>.</p> <dl class="py method"> <dt class="sig sig-object py" id="jax.sharding.PmapSharding.default"> <em class="property"><span class="pre">classmethod</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">default</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">shape</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">sharded_dim</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">devices</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding_impls.py#L256-L300"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.PmapSharding.default" title="Link to this definition">#</a></dt> <dd><p>Creates a <a class="reference internal" href="#jax.sharding.PmapSharding" title="jax.sharding.PmapSharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">PmapSharding</span></code></a> which matches the default placement used by <a class="reference internal" href="_autosummary/jax.pmap.html#jax.pmap" title="jax.pmap"><code class="xref py py-func docutils literal notranslate"><span class="pre">jax.pmap()</span></code></a>.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters<span class="colon">:</span></dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>shape</strong> (<em>Shape</em>) – The shape of the input array.</p></li> <li><p><strong>sharded_dim</strong> (<a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.13)"><em>int</em></a><em> | 
</em><em>None</em>) – Dimension the input array is sharded on. Defaults to 0.</p></li> <li><p><strong>devices</strong> (<em>Sequence</em><em>[</em><em>xc.Device</em><em>] </em><em>| </em><em>None</em>) – Optional sequence of devices to use. If omitted, the implicit device order used by pmap is used, which is the order of <a class="reference internal" href="_autosummary/jax.local_devices.html#jax.local_devices" title="jax.local_devices"><code class="xref py py-func docutils literal notranslate"><span class="pre">jax.local_devices()</span></code></a>.</p></li> </ul> </dd> <dt class="field-even">Return type<span class="colon">:</span></dt> <dd class="field-even"><p><a class="reference internal" href="#jax.sharding.PmapSharding" title="jax.sharding.PmapSharding">PmapSharding</a></p> </dd> </dl> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.PmapSharding.device_set"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">device_set</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#set" title="(in Python v3.13)"><span class="pre">set</span></a><span class="p"><span class="pre">[</span></span><a class="reference internal" href="_autosummary/jax.Device.html#jax.Device" title="jax.Device"><span class="pre">Device</span></a><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#jax.sharding.PmapSharding.device_set" title="Link to this definition">#</a></dt> <dd><p>The set of devices that this <a class="reference internal" href="#jax.sharding.Sharding" title="jax.sharding.Sharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sharding</span></code></a> spans.</p> <p>In multi-controller JAX, the set of devices is global, i.e., includes
non-addressable devices from other processes.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.PmapSharding.devices"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">devices</span></span><a class="headerlink" href="#jax.sharding.PmapSharding.devices" title="Link to this definition">#</a></dt> <dd><p>The devices of this sharding, as a NumPy <code class="docutils literal notranslate"><span class="pre">ndarray</span></code>.</p> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="jax.sharding.PmapSharding.devices_indices_map"> <span class="sig-name descname"><span class="pre">devices_indices_map</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">global_shape</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding_impls.py#L309-L311"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.PmapSharding.devices_indices_map" title="Link to this definition">#</a></dt> <dd><p>Returns a mapping from devices to the array slices each contains.</p> <p>The mapping includes all global devices, i.e., including non-addressable devices from other processes.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters<span class="colon">:</span></dt> <dd class="field-odd"><p><strong>global_shape</strong> (<em>Shape</em>)</p> </dd> <dt class="field-even">Return type<span class="colon">:</span></dt> <dd class="field-even"><p>Mapping[<a class="reference internal" href="_autosummary/jax.Device.html#jax.Device" title="jax.Device">Device</a>, Index]</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="jax.sharding.PmapSharding.is_equivalent_to"> <span class="sig-name descname"><span class="pre">is_equivalent_to</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span
class="pre">other</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">ndim</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding_impls.py#L251-L254"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.PmapSharding.is_equivalent_to" title="Link to this definition">#</a></dt> <dd><p>Returns <code class="docutils literal notranslate"><span class="pre">True</span></code> if two shardings are equivalent.</p> <p>Two shardings are equivalent if they place the same logical array shards on the same devices.</p> <p>For example, a <a class="reference internal" href="#jax.sharding.NamedSharding" title="jax.sharding.NamedSharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">NamedSharding</span></code></a> may be equivalent to a <a class="reference internal" href="#jax.sharding.PositionalSharding" title="jax.sharding.PositionalSharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">PositionalSharding</span></code></a> if both place the same shards of the array on the same devices.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters<span class="colon">:</span></dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>self</strong> (<a class="reference internal" href="#jax.sharding.PmapSharding" title="jax.sharding.PmapSharding"><em>PmapSharding</em></a>)</p></li> <li><p><strong>other</strong> (<a class="reference internal" href="#jax.sharding.PmapSharding" title="jax.sharding.PmapSharding"><em>PmapSharding</em></a>)</p></li> <li><p><strong>ndim</strong> (<a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.13)"><em>int</em></a>)</p></li> </ul> </dd> <dt class="field-even">Return type<span class="colon">:</span></dt> <dd class="field-even"><p><a class="reference 
external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.13)">bool</a></p> </dd> </dl> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.PmapSharding.is_fully_addressable"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">is_fully_addressable</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.13)"><span class="pre">bool</span></a></em><a class="headerlink" href="#jax.sharding.PmapSharding.is_fully_addressable" title="Link to this definition">#</a></dt> <dd><p>Is this sharding fully addressable?</p> <p>A sharding is fully addressable if the current process can address all of the devices named in the <a class="reference internal" href="#jax.sharding.Sharding" title="jax.sharding.Sharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sharding</span></code></a>. 
<code class="docutils literal notranslate"><span class="pre">is_fully_addressable</span></code> is equivalent to “is_local” in multi-process JAX.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.PmapSharding.is_fully_replicated"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">is_fully_replicated</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.13)"><span class="pre">bool</span></a></em><a class="headerlink" href="#jax.sharding.PmapSharding.is_fully_replicated" title="Link to this definition">#</a></dt> <dd><p>Is this sharding fully replicated?</p> <p>A sharding is fully replicated if each device has a complete copy of the entire data.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.PmapSharding.memory_kind"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">memory_kind</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><span class="pre">str</span></a><span class="w"> </span><span class="p"><span class="pre">|</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.13)"><span class="pre">None</span></a></em><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding_impls.py#L316-L322"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.PmapSharding.memory_kind" title="Link to this definition">#</a></dt> 
<dd><p>Returns the memory kind of the sharding.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.PmapSharding.num_devices"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">num_devices</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.13)"><span class="pre">int</span></a></em><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding_impls.py#L301-L304"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.PmapSharding.num_devices" title="Link to this definition">#</a></dt> <dd><p>Number of devices that the sharding contains.</p> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="jax.sharding.PmapSharding.shard_shape"> <span class="sig-name descname"><span class="pre">shard_shape</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">global_shape</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding_impls.py#L343-L367"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.PmapSharding.shard_shape" title="Link to this definition">#</a></dt> <dd><p>Returns the shape of the data on each device.</p> <p>The shard shape returned by this function is calculated from <code class="docutils literal notranslate"><span class="pre">global_shape</span></code> and the properties of the sharding.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters<span class="colon">:</span></dt> <dd class="field-odd"><p><strong>global_shape</strong> (<em>Shape</em>)</p> </dd> <dt 
class="field-even">Return type<span class="colon">:</span></dt> <dd class="field-even"><p>Shape</p> </dd> </dl> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.PmapSharding.sharding_spec"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">sharding_spec</span></span><a class="headerlink" href="#jax.sharding.PmapSharding.sharding_spec" title="Link to this definition">#</a></dt> <dd><p>The underlying <code class="docutils literal notranslate"><span class="pre">ShardingSpec</span></code> that describes how <a class="reference internal" href="_autosummary/jax.pmap.html#jax.pmap" title="jax.pmap"><code class="xref py py-func docutils literal notranslate"><span class="pre">jax.pmap()</span></code></a> lays the array out across devices.</p> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="jax.sharding.PmapSharding.with_memory_kind"> <span class="sig-name descname"><span class="pre">with_memory_kind</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">kind</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding_impls.py#L323-L325"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.PmapSharding.with_memory_kind" title="Link to this definition">#</a></dt> <dd><p>Returns a new Sharding instance with the specified memory kind.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters<span class="colon">:</span></dt> <dd class="field-odd"><p><strong>kind</strong> (<a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><em>str</em></a>)</p> </dd> </dl> </dd></dl> </dd></dl> <dl class="py class"> <dt class="sig sig-object py" id="jax.sharding.GSPMDSharding"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">jax.sharding.</span></span><span class="sig-name descname"><span class="pre">GSPMDSharding</span></span><a class="headerlink" href="#jax.sharding.GSPMDSharding" title="Link to this definition">#</a></dt>
<dd><p>Bases: <a class="reference internal" href="#jax.sharding.Sharding" title="jaxlib.xla_extension.Sharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sharding</span></code></a></p> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.GSPMDSharding.device_set"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">device_set</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#set" title="(in Python v3.13)"><span class="pre">set</span></a><span class="p"><span class="pre">[</span></span><a class="reference internal" href="_autosummary/jax.Device.html#jax.Device" title="jax.Device"><span class="pre">Device</span></a><span class="p"><span class="pre">]</span></span></em><a class="headerlink" href="#jax.sharding.GSPMDSharding.device_set" title="Link to this definition">#</a></dt> <dd><p>The set of devices that this <a class="reference internal" href="#jax.sharding.Sharding" title="jax.sharding.Sharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sharding</span></code></a> spans.</p> <p>In multi-controller JAX, the set of devices is global, i.e., includes non-addressable devices from other processes.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.GSPMDSharding.is_fully_addressable"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">is_fully_addressable</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.13)"><span class="pre">bool</span></a></em><a class="headerlink" 
href="#jax.sharding.GSPMDSharding.is_fully_addressable" title="Link to this definition">#</a></dt> <dd><p>Is this sharding fully addressable?</p> <p>A sharding is fully addressable if the current process can address all of the devices named in the <a class="reference internal" href="#jax.sharding.Sharding" title="jax.sharding.Sharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">Sharding</span></code></a>. <code class="docutils literal notranslate"><span class="pre">is_fully_addressable</span></code> is equivalent to “is_local” in multi-process JAX.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.GSPMDSharding.is_fully_replicated"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">is_fully_replicated</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#bool" title="(in Python v3.13)"><span class="pre">bool</span></a></em><a class="headerlink" href="#jax.sharding.GSPMDSharding.is_fully_replicated" title="Link to this definition">#</a></dt> <dd><p>Is this sharding fully replicated?</p> <p>A sharding is fully replicated if each device has a complete copy of the entire data.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.GSPMDSharding.memory_kind"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">memory_kind</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><span class="pre">str</span></a><span class="w"> </span><span class="p"><span class="pre">|</span></span><span class="w"> </span><a 
class="reference external" href="https://docs.python.org/3/library/constants.html#None" title="(in Python v3.13)"><span class="pre">None</span></a></em><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding_impls.py#L641-L644"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.GSPMDSharding.memory_kind" title="Link to this definition">#</a></dt> <dd><p>Returns the memory kind of the sharding.</p> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="jax.sharding.GSPMDSharding.num_devices"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">num_devices</span></span><em class="property"><span class="p"><span class="pre">:</span></span><span class="w"> </span><a class="reference external" href="https://docs.python.org/3/library/functions.html#int" title="(in Python v3.13)"><span class="pre">int</span></a></em><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding_impls.py#L633-L636"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.GSPMDSharding.num_devices" title="Link to this definition">#</a></dt> <dd><p>Number of devices that the sharding contains.</p> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="jax.sharding.GSPMDSharding.with_memory_kind"> <span class="sig-name descname"><span class="pre">with_memory_kind</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">kind</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/sharding_impls.py#L645-L647"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.GSPMDSharding.with_memory_kind" title="Link to this 
definition">#</a></dt> <dd><p>Returns a new Sharding instance with the specified memory kind.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters<span class="colon">:</span></dt> <dd class="field-odd"><p><strong>kind</strong> (<a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#str" title="(in Python v3.13)"><em>str</em></a>)</p> </dd> <dt class="field-even">Return type<span class="colon">:</span></dt> <dd class="field-even"><p><a class="reference internal" href="#jax.sharding.GSPMDSharding" title="jax.sharding.GSPMDSharding">GSPMDSharding</a></p> </dd> </dl> </dd></dl> </dd></dl> <dl class="py class"> <dt class="sig sig-object py" id="jax.sharding.PartitionSpec"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">jax.sharding.</span></span><span class="sig-name descname"><span class="pre">PartitionSpec</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">partitions</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/partition_spec.py#L35-L77"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.PartitionSpec" title="Link to this definition">#</a></dt> <dd><p>Tuple describing how to partition an array across a mesh of devices.</p> <p>Each element is either <code class="docutils literal notranslate"><span class="pre">None</span></code>, a string, or a tuple of strings. 
See the documentation of <a class="reference internal" href="#jax.sharding.NamedSharding" title="jax.sharding.NamedSharding"><code class="xref py py-class docutils literal notranslate"><span class="pre">jax.sharding.NamedSharding</span></code></a> for more details.</p> <p>This class exists so JAX’s pytree utilities can distinguish partition specifications from tuples that should be treated as pytrees.</p> </dd></dl> <dl class="py class"> <dt class="sig sig-object py" id="jax.sharding.Mesh"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">jax.sharding.</span></span><span class="sig-name descname"><span class="pre">Mesh</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">devices</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">axis_names</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">axis_types</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference external" href="https://github.com/jax-ml/jax/blob/main/jax/_src/mesh.py#L197-L420"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#jax.sharding.Mesh" title="Link to this definition">#</a></dt> <dd><p>Declare the hardware resources available in the scope of this manager.</p> <p>In particular, all <code class="docutils literal notranslate"><span class="pre">axis_names</span></code> become valid resource names inside the managed block and can be used e.g. 
in the <code class="docutils literal notranslate"><span class="pre">in_shardings</span></code> argument of <a class="reference internal" href="jax.experimental.pjit.html#jax.experimental.pjit.pjit" title="jax.experimental.pjit.pjit"><code class="xref py py-func docutils literal notranslate"><span class="pre">jax.experimental.pjit.pjit()</span></code></a>. Also see JAX’s multi-process programming model (<a class="reference external" href="https://jax.readthedocs.io/en/latest/multi_process.html">https://jax.readthedocs.io/en/latest/multi_process.html</a>) and the Distributed arrays and automatic parallelization tutorial (<a class="reference external" href="https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html">https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html</a>).</p> <p>If you are compiling in multiple threads, make sure that the <code class="docutils literal notranslate"><span class="pre">with</span> <span class="pre">Mesh</span></code> context manager is inside the function that the threads will execute.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters<span class="colon">:</span></dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>devices</strong> (<em>np.ndarray</em>) – A NumPy ndarray object containing JAX device objects (as obtained e.g. 
from <a class="reference internal" href="_autosummary/jax.devices.html#jax.devices" title="jax.devices"><code class="xref py py-func docutils literal notranslate"><span class="pre">jax.devices()</span></code></a>).</p></li> <li><p><strong>axis_names</strong> (<a class="reference external" href="https://docs.python.org/3/library/stdtypes.html#tuple" title="(in Python v3.13)"><em>tuple</em></a><em>[</em><em>MeshAxisName</em><em>, </em><em>...</em><em>]</em>) – A sequence of resource axis names to be assigned to the dimensions of the <code class="docutils literal notranslate"><span class="pre">devices</span></code> argument. Its length should match the rank of <code class="docutils literal notranslate"><span class="pre">devices</span></code>.</p></li> <li><p><strong>axis_types</strong> (<em>MeshAxisType</em><em> | </em><em>None</em>)</p></li> </ul> </dd> </dl> <p class="rubric">Examples</p> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span><span class="w"> </span><span class="nn">jax</span> <span class="gp">>>> </span><span class="kn">from</span><span class="w"> </span><span class="nn">jax.experimental.pjit</span><span class="w"> </span><span class="kn">import</span> <span class="n">pjit</span> <span class="gp">>>> </span><span class="kn">from</span><span class="w"> </span><span class="nn">jax.sharding</span><span class="w"> </span><span class="kn">import</span> <span class="n">Mesh</span> <span class="gp">>>> </span><span class="kn">from</span><span class="w"> </span><span class="nn">jax.sharding</span><span class="w"> </span><span class="kn">import</span> <span class="n">PartitionSpec</span> <span class="k">as</span> <span class="n">P</span> <span class="gp">>>> </span><span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span> <span class="gp">...</span> <span class="gp">>>> </span><span class="n">inp</span> <span class="o">=</span> <span 
class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">16</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">8</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">devices</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">devices</span><span class="p">())</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="gp">...</span> <span class="gp">>>> </span><span class="c1"># Declare a 2D mesh with axes `x` and `y`.</span> <span class="gp">>>> </span><span class="n">global_mesh</span> <span class="o">=</span> <span class="n">Mesh</span><span class="p">(</span><span class="n">devices</span><span class="p">,</span> <span class="p">(</span><span class="s1">'x'</span><span class="p">,</span> <span class="s1">'y'</span><span class="p">))</span> <span class="gp">>>> </span><span class="c1"># Use the mesh object directly as a context manager.</span> <span class="gp">>>> </span><span class="k">with</span> <span class="n">global_mesh</span><span class="p">:</span> <span class="gp">... 
</span> <span class="n">out</span> <span class="o">=</span> <span class="n">pjit</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="p">,</span> <span class="n">in_shardings</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">out_shardings</span><span class="o">=</span><span class="kc">None</span><span class="p">)(</span><span class="n">inp</span><span class="p">)</span> </pre></div> </div> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Initialize the Mesh and use the mesh as the context manager.</span> <span class="gp">>>> </span><span class="k">with</span> <span class="n">Mesh</span><span class="p">(</span><span class="n">devices</span><span class="p">,</span> <span class="p">(</span><span class="s1">'x'</span><span class="p">,</span> <span class="s1">'y'</span><span class="p">))</span> <span class="k">as</span> <span class="n">global_mesh</span><span class="p">:</span> <span class="gp">... </span> <span class="n">out</span> <span class="o">=</span> <span class="n">pjit</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="p">,</span> <span class="n">in_shardings</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">out_shardings</span><span class="o">=</span><span class="kc">None</span><span class="p">)(</span><span class="n">inp</span><span class="p">)</span> </pre></div> </div> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># Also you can use it as `with ... 
as ...`.</span> <span class="gp">>>> </span><span class="n">global_mesh</span> <span class="o">=</span> <span class="n">Mesh</span><span class="p">(</span><span class="n">devices</span><span class="p">,</span> <span class="p">(</span><span class="s1">'x'</span><span class="p">,</span> <span class="s1">'y'</span><span class="p">))</span> <span class="gp">>>> </span><span class="k">with</span> <span class="n">global_mesh</span> <span class="k">as</span> <span class="n">m</span><span class="p">:</span> <span class="gp">... </span> <span class="n">out</span> <span class="o">=</span> <span class="n">pjit</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="p">,</span> <span class="n">in_shardings</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">out_shardings</span><span class="o">=</span><span class="kc">None</span><span class="p">)(</span><span class="n">inp</span><span class="p">)</span> </pre></div> </div> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="c1"># You can also use it as `with Mesh(...)`.</span> <span class="gp">>>> </span><span class="k">with</span> <span class="n">Mesh</span><span class="p">(</span><span class="n">devices</span><span class="p">,</span> <span class="p">(</span><span class="s1">'x'</span><span class="p">,</span> <span class="s1">'y'</span><span class="p">)):</span> <span class="gp">... 
</span> <span class="n">out</span> <span class="o">=</span> <span class="n">pjit</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="p">,</span> <span class="n">in_shardings</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">out_shardings</span><span class="o">=</span><span class="kc">None</span><span class="p">)(</span><span class="n">inp</span><span class="p">)</span> </pre></div> </div> </dd></dl> </section> </section> </article> <footer class="prev-next-footer"> <div class="prev-next-area"> <a class="left-prev" href="_autosummary/jax.random.weibull_min.html" title="previous page"> <i class="fa-solid fa-angle-left"></i> <div class="prev-next-info"> <p class="prev-next-subtitle">previous</p> <p class="prev-next-title">jax.random.weibull_min</p> </div> </a> <a class="right-next" href="jax.debug.html" title="next page"> <div class="prev-next-info"> <p class="prev-next-subtitle">next</p> <p class="prev-next-title"><code class="docutils literal notranslate"><span class="pre">jax.debug</span></code> module</p> </div> <i class="fa-solid fa-angle-right"></i> </a> </div> </footer> </div> <div class="bd-sidebar-secondary bd-toc"><div class="sidebar-secondary-items sidebar-secondary__inner"> <div class="sidebar-secondary-item"> <div class="page-toc tocsection onthispage"> <i class="fa-solid fa-list"></i> Contents </div> <nav class="bd-toc-nav page-toc"> <ul class="visible nav section-nav flex-column"> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#classes">Classes</a><ul class="visible nav section-nav flex-column"> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding"><code class="docutils literal notranslate"><span class="pre">Sharding</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a 
class="reference internal nav-link" href="#jax.sharding.Sharding.addressable_devices"><code class="docutils literal notranslate"><span class="pre">Sharding.addressable_devices</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding.addressable_devices_indices_map"><code class="docutils literal notranslate"><span class="pre">Sharding.addressable_devices_indices_map()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding.device_set"><code class="docutils literal notranslate"><span class="pre">Sharding.device_set</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding.devices_indices_map"><code class="docutils literal notranslate"><span class="pre">Sharding.devices_indices_map()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding.is_equivalent_to"><code class="docutils literal notranslate"><span class="pre">Sharding.is_equivalent_to()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding.is_fully_addressable"><code class="docutils literal notranslate"><span class="pre">Sharding.is_fully_addressable</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding.is_fully_replicated"><code class="docutils literal notranslate"><span class="pre">Sharding.is_fully_replicated</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding.memory_kind"><code class="docutils literal notranslate"><span class="pre">Sharding.memory_kind</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding.num_devices"><code 
class="docutils literal notranslate"><span class="pre">Sharding.num_devices</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding.shard_shape"><code class="docutils literal notranslate"><span class="pre">Sharding.shard_shape()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.Sharding.with_memory_kind"><code class="docutils literal notranslate"><span class="pre">Sharding.with_memory_kind()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.SingleDeviceSharding"><code class="docutils literal notranslate"><span class="pre">SingleDeviceSharding</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.SingleDeviceSharding.device_set"><code class="docutils literal notranslate"><span class="pre">SingleDeviceSharding.device_set</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.SingleDeviceSharding.devices_indices_map"><code class="docutils literal notranslate"><span class="pre">SingleDeviceSharding.devices_indices_map()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.SingleDeviceSharding.is_fully_addressable"><code class="docutils literal notranslate"><span class="pre">SingleDeviceSharding.is_fully_addressable</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.SingleDeviceSharding.is_fully_replicated"><code class="docutils literal notranslate"><span class="pre">SingleDeviceSharding.is_fully_replicated</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" 
href="#jax.sharding.SingleDeviceSharding.memory_kind"><code class="docutils literal notranslate"><span class="pre">SingleDeviceSharding.memory_kind</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.SingleDeviceSharding.num_devices"><code class="docutils literal notranslate"><span class="pre">SingleDeviceSharding.num_devices</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.SingleDeviceSharding.with_memory_kind"><code class="docutils literal notranslate"><span class="pre">SingleDeviceSharding.with_memory_kind()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.NamedSharding"><code class="docutils literal notranslate"><span class="pre">NamedSharding</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.NamedSharding.addressable_devices"><code class="docutils literal notranslate"><span class="pre">NamedSharding.addressable_devices</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.NamedSharding.device_set"><code class="docutils literal notranslate"><span class="pre">NamedSharding.device_set</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.NamedSharding.is_fully_addressable"><code class="docutils literal notranslate"><span class="pre">NamedSharding.is_fully_addressable</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.NamedSharding.is_fully_replicated"><code class="docutils literal notranslate"><span class="pre">NamedSharding.is_fully_replicated</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" 
href="#jax.sharding.NamedSharding.memory_kind"><code class="docutils literal notranslate"><span class="pre">NamedSharding.memory_kind</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.NamedSharding.mesh"><code class="docutils literal notranslate"><span class="pre">NamedSharding.mesh</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.NamedSharding.num_devices"><code class="docutils literal notranslate"><span class="pre">NamedSharding.num_devices</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.NamedSharding.spec"><code class="docutils literal notranslate"><span class="pre">NamedSharding.spec</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.NamedSharding.with_memory_kind"><code class="docutils literal notranslate"><span class="pre">NamedSharding.with_memory_kind()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PositionalSharding"><code class="docutils literal notranslate"><span class="pre">PositionalSharding</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PositionalSharding.device_set"><code class="docutils literal notranslate"><span class="pre">PositionalSharding.device_set</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PositionalSharding.is_fully_addressable"><code class="docutils literal notranslate"><span class="pre">PositionalSharding.is_fully_addressable</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PositionalSharding.is_fully_replicated"><code 
class="docutils literal notranslate"><span class="pre">PositionalSharding.is_fully_replicated</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PositionalSharding.memory_kind"><code class="docutils literal notranslate"><span class="pre">PositionalSharding.memory_kind</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PositionalSharding.num_devices"><code class="docutils literal notranslate"><span class="pre">PositionalSharding.num_devices</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PositionalSharding.with_memory_kind"><code class="docutils literal notranslate"><span class="pre">PositionalSharding.with_memory_kind()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding"><code class="docutils literal notranslate"><span class="pre">PmapSharding</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.default"><code class="docutils literal notranslate"><span class="pre">PmapSharding.default()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.device_set"><code class="docutils literal notranslate"><span class="pre">PmapSharding.device_set</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.devices"><code class="docutils literal notranslate"><span class="pre">PmapSharding.devices</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.devices_indices_map"><code class="docutils literal notranslate"><span 
class="pre">PmapSharding.devices_indices_map()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.is_equivalent_to"><code class="docutils literal notranslate"><span class="pre">PmapSharding.is_equivalent_to()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.is_fully_addressable"><code class="docutils literal notranslate"><span class="pre">PmapSharding.is_fully_addressable</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.is_fully_replicated"><code class="docutils literal notranslate"><span class="pre">PmapSharding.is_fully_replicated</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.memory_kind"><code class="docutils literal notranslate"><span class="pre">PmapSharding.memory_kind</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.num_devices"><code class="docutils literal notranslate"><span class="pre">PmapSharding.num_devices</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.shard_shape"><code class="docutils literal notranslate"><span class="pre">PmapSharding.shard_shape()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.sharding_spec"><code class="docutils literal notranslate"><span class="pre">PmapSharding.sharding_spec</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PmapSharding.with_memory_kind"><code class="docutils literal notranslate"><span class="pre">PmapSharding.with_memory_kind()</span></code></a></li> </ul> 
</li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.GSPMDSharding"><code class="docutils literal notranslate"><span class="pre">GSPMDSharding</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.GSPMDSharding.device_set"><code class="docutils literal notranslate"><span class="pre">GSPMDSharding.device_set</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.GSPMDSharding.is_fully_addressable"><code class="docutils literal notranslate"><span class="pre">GSPMDSharding.is_fully_addressable</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.GSPMDSharding.is_fully_replicated"><code class="docutils literal notranslate"><span class="pre">GSPMDSharding.is_fully_replicated</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.GSPMDSharding.memory_kind"><code class="docutils literal notranslate"><span class="pre">GSPMDSharding.memory_kind</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.GSPMDSharding.num_devices"><code class="docutils literal notranslate"><span class="pre">GSPMDSharding.num_devices</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.GSPMDSharding.with_memory_kind"><code class="docutils literal notranslate"><span class="pre">GSPMDSharding.with_memory_kind()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#jax.sharding.PartitionSpec"><code class="docutils literal notranslate"><span class="pre">PartitionSpec</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" 
href="#jax.sharding.Mesh"><code class="docutils literal notranslate"><span class="pre">Mesh</span></code></a></li> </ul> </li> </ul> </nav></div> </div></div> </div> <footer class="bd-footer-content"> <div class="bd-footer-content__inner container"> <div class="footer-item"> <p class="component-author"> By The JAX authors </p> </div> <div class="footer-item"> <p class="copyright"> © Copyright 2024, The JAX Authors. <br/> </p> </div> <div class="footer-item"> </div> <div class="footer-item"> </div> </div> </footer> </main> </div> </div> <!-- Scripts loaded after <body> so the DOM is not blocked --> <script src="_static/scripts/bootstrap.js?digest=5b4479735964841361fd"></script> <script src="_static/scripts/pydata-sphinx-theme.js?digest=5b4479735964841361fd"></script> <footer class="bd-footer"> </footer> </body> </html>