CINXE.COM
Layers
<!DOCTYPE html> <html lang="en" data-content_root="" > <head> <meta charset="utf-8" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" /> <title>Layers</title> <script data-cfasync="false"> document.documentElement.dataset.mode = localStorage.getItem("mode") || ""; document.documentElement.dataset.theme = localStorage.getItem("theme") || ""; </script> <!-- Loaded before other Sphinx assets --> <link href="../../_static/styles/theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link href="../../_static/styles/bootstrap.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link href="../../_static/styles/pydata-sphinx-theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link href="../../_static/vendor/fontawesome/6.5.2/css/all.min.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="../../_static/vendor/fontawesome/6.5.2/webfonts/fa-solid-900.woff2" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="../../_static/vendor/fontawesome/6.5.2/webfonts/fa-brands-400.woff2" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="../../_static/vendor/fontawesome/6.5.2/webfonts/fa-regular-400.woff2" /> <link rel="stylesheet" type="text/css" href="../../_static/pygments.css" /> <link rel="stylesheet" type="text/css" href="../../_static/styles/sphinx-book-theme.css" /> <link rel="stylesheet" type="text/css" href="../../_static/mystnb.4510f1fc1dee50b3e5859aac5469c37c29e427902b24a333a5f9fcb2f0b3ac41.css" /> <link rel="stylesheet" type="text/css" href="../../_static/sphinx-design.5ea377869091fd0449014c60fc090103.min.css" /> <link rel="stylesheet" type="text/css" href="../../_static/css/flax_theme.css" /> <!-- Pre-loaded scripts that we'll load fully later --> <link rel="preload" as="script" href="../../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b" /> <link rel="preload" as="script" href="../../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" /> <script src="../../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script> <script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script> <script src="../../_static/jquery.js"></script> <script src="../../_static/underscore.js"></script> <script src="../../_static/_sphinx_javascript_frameworks_compat.js"></script> <script src="../../_static/doctools.js"></script> <script src="../../_static/sphinx_highlight.js"></script> <script src="../../_static/scripts/sphinx-book-theme.js"></script> <script src="../../_static/design-tabs.js"></script> <script>window.MathJax = {"options": {"processHtmlClass": "tex2jax_process|mathjax_process|math|output_area"}}</script> <script defer="defer" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> <script>DOCUMENTATION_OPTIONS.pagename = 'api_reference/flax.linen/layers';</script> <link rel="shortcut icon" href="../../_static/flax.png"/> <link rel="index" title="Index" href="../../genindex.html" /> <link rel="search" title="Search" href="../../search.html" /> <link rel="next" title="Activation functions" href="activation_functions.html" /> <link rel="prev" title="Init/Apply" href="init_apply.html" /> <meta name="viewport" content="width=device-width, initial-scale=1"/> <meta name="docsearch:language" content="en"/> <script async type="text/javascript" src="/_/static/javascript/readthedocs-addons.js"></script><meta name="readthedocs-project-slug" content="flax-linen" /><meta name="readthedocs-version-slug" content="latest" /><meta name="readthedocs-resolver-filename" content="/api_reference/flax.linen/layers.html" /><meta name="readthedocs-http-status" content="200" /></head> <body data-bs-spy="scroll" data-bs-target=".bd-toc-nav" data-offset="180" data-bs-root-margin="0px 0px -60%" data-default-mode=""> <div id="pst-skip-link" class="skip-link d-print-none"><a href="#main-content">Skip to main content</a></div> <div id="pst-scroll-pixel-helper"></div> <button type="button" class="btn rounded-pill" id="pst-back-to-top"> <i class="fa-solid fa-arrow-up"></i>Back to top</button> <input type="checkbox" class="sidebar-toggle" id="pst-primary-sidebar-checkbox"/> <label class="overlay overlay-primary" for="pst-primary-sidebar-checkbox"></label> <input type="checkbox" class="sidebar-toggle" id="pst-secondary-sidebar-checkbox"/> <label class="overlay overlay-secondary" for="pst-secondary-sidebar-checkbox"></label> <div class="search-button__wrapper"> <div class="search-button__overlay"></div> <div class="search-button__search-container"> <form class="bd-search d-flex align-items-center" action="../../search.html" method="get"> <i class="fa-solid fa-magnifying-glass"></i> <input type="search" class="form-control" name="q" id="search-input" placeholder="Search..." aria-label="Search..." autocomplete="off" autocorrect="off" autocapitalize="off" spellcheck="false"/> <span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span> </form></div> </div> <div class="pst-async-banner-revealer d-none"> <aside id="bd-header-version-warning" class="d-none d-print-none" aria-label="Version warning"></aside> </div> <aside class="bd-header-announcement" aria-label="Announcement"> <div class="bd-header-announcement__content"> <a href="https://flax.readthedocs.io/en/latest/index.html" style="text-decoration: none; color: white;" > This site covers the old Flax Linen API. <span style="color: lightgray;">[Explore the new <b>Flax NNX</b> API ✨]</span> </a> </div> </aside> <header class="bd-header navbar navbar-expand-lg bd-navbar d-print-none"> </header> <div class="bd-container"> <div class="bd-container__inner bd-page-width"> <div class="bd-sidebar-primary bd-sidebar"> <div class="sidebar-header-items sidebar-primary__section"> </div> <div class="sidebar-primary-items__start sidebar-primary__section"> <div class="sidebar-primary-item"> <a class="navbar-brand logo" href="../../index.html"> <img src="../../_static/flax.png" class="logo__image only-light" alt=" - Home"/> <script>document.write(`<img src="../../_static/flax.png" class="logo__image only-dark" alt=" - Home"/>`);</script> </a></div> <div class="sidebar-primary-item"> <script> document.write(` <button class="btn search-button-field search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="fa-solid fa-magnifying-glass"></i> <span class="search-button__default-text">Search</span> <span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span> </button> `); </script></div> <div class="sidebar-primary-item"><nav class="bd-links bd-docs-nav" aria-label="Main"> <div class="bd-toc-item navbar-nav active"> <ul class="current nav bd-sidenav"> <li class="toctree-l1"><a class="reference internal" href="../../quick_start.html">Quick start</a></li> <li class="toctree-l1"><a class="reference internal" href="../../guides/flax_fundamentals/flax_basics.html">Flax Basics</a></li> <li class="toctree-l1 has-children"><a class="reference internal" href="../../guides/index.html">Guides</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l2 has-children"><a class="reference internal" href="../../guides/flax_fundamentals/index.html">Flax fundamentals</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference external" href="https://jax.readthedocs.io/en/latest/jax-101/index.html">JAX 101</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/flax_fundamentals/flax_basics.html">Flax Basics</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/flax_fundamentals/state_params.html">Managing Parameters and State</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/flax_fundamentals/setup_or_nncompact.html"><code class="docutils literal notranslate"><span class="pre">setup</span></code> vs <code class="docutils literal notranslate"><span class="pre">compact</span></code></a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/flax_fundamentals/arguments.html">Dealing with Flax Module arguments</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/flax_fundamentals/rng_guide.html">Randomness and PRNGs in Flax</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../../guides/data_preprocessing/index.html">Data preprocessing</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../../guides/data_preprocessing/full_eval.html">Processing the entire Dataset</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/data_preprocessing/loading_datasets.html">Loading datasets</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../../guides/training_techniques/index.html">Training techniques</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../../guides/training_techniques/batch_norm.html">Batch normalization</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/training_techniques/dropout.html">Dropout</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/training_techniques/lr_schedule.html">Learning rate scheduling</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/training_techniques/transfer_learning.html">Transfer learning</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/training_techniques/use_checkpointing.html">Save and load checkpoints</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../../guides/parallel_training/index.html">Parallel training</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../../guides/parallel_training/ensembling.html">Ensembling on multiple devices</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/parallel_training/flax_on_pjit.html">Scale up Flax Modules on multiple devices</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../../guides/model_inspection/index.html">Model inspection</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../../guides/model_inspection/model_surgery.html">Model surgery</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/model_inspection/extracting_intermediates.html">Extracting intermediate values</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../../guides/converting_and_upgrading/index.html">Converting and upgrading</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../../guides/converting_and_upgrading/haiku_migration_guide.html">Migrating from Haiku to Flax</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/converting_and_upgrading/convert_pytorch_to_flax.html">Convert PyTorch models to Flax</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/converting_and_upgrading/orbax_upgrade_guide.html">Migrate checkpointing to Orbax</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/converting_and_upgrading/optax_update_guide.html">Upgrading my codebase to Optax</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/converting_and_upgrading/linen_upgrade_guide.html">Upgrading my codebase to Linen</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/converting_and_upgrading/rnncell_upgrade_guide.html">RNNCellBase Upgrade Guide</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/converting_and_upgrading/regular_dict_upgrade_guide.html">Migrate to regular dicts</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../../guides/quantization/index.html">Quantization</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../../guides/quantization/fp8_basics.html">User Guide on Using FP8</a></li> </ul> </details></li> <li class="toctree-l2"><a class="reference internal" href="../../guides/flax_sharp_bits.html">The Sharp Bits</a></li> </ul> </details></li> <li class="toctree-l1 has-children"><a class="reference internal" href="../../examples/index.html">Examples</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l2"><a class="reference internal" href="../../examples/core_examples.html">Core examples</a></li> <li class="toctree-l2"><a class="reference internal" href="../../examples/google_research_examples.html">Google Research examples</a></li> <li class="toctree-l2"><a class="reference internal" href="../../examples/repositories_that_use_flax.html">Repositories that use Flax</a></li> <li class="toctree-l2"><a class="reference internal" href="../../examples/community_examples.html">Community examples</a></li> </ul> </details></li> <li class="toctree-l1"><a class="reference internal" href="../../glossary.html">Glossary</a></li> <li class="toctree-l1"><a class="reference internal" href="../../faq.html">Frequently Asked Questions (FAQ)</a></li> <li class="toctree-l1 has-children"><a class="reference internal" href="../../developer_notes/index.html">Developer notes</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l2"><a class="reference internal" href="../../developer_notes/module_lifecycle.html">The Flax Module lifecycle</a></li> <li class="toctree-l2"><a class="reference internal" href="../../developer_notes/lift.html">Lifted transformations</a></li> <li class="toctree-l2"><a class="reference external" href="https://github.com/google/flax/tree/main/docs/flip">FLIPs</a></li> </ul> </details></li> <li class="toctree-l1"><a class="reference internal" href="../../philosophy.html">The Flax philosophy</a></li> <li class="toctree-l1"><a class="reference internal" href="../../contributing.html">How to contribute</a></li> <li class="toctree-l1 current active has-children"><a class="reference internal" href="../index.html">API Reference</a><details open="open"><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul class="current"> <li class="toctree-l2"><a class="reference internal" href="../flax.config.html">flax.config package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.core.frozen_dict.html">flax.core.frozen_dict package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.cursor.html">flax.cursor package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.errors.html">flax.errors package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.jax_utils.html">flax.jax_utils package</a></li> <li class="toctree-l2 current active has-children"><a class="reference internal" href="index.html">flax.linen</a><details open="open"><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul class="current"> <li class="toctree-l3"><a class="reference internal" href="module.html">Module</a></li> <li class="toctree-l3"><a class="reference internal" href="init_apply.html">Init/Apply</a></li> <li class="toctree-l3 current active"><a class="current reference internal" href="#">Layers</a></li> <li class="toctree-l3"><a class="reference internal" href="activation_functions.html">Activation functions</a></li> <li class="toctree-l3"><a class="reference internal" href="initializers.html">Initializers</a></li> <li class="toctree-l3"><a class="reference internal" href="transformations.html">Transformations</a></li> <li class="toctree-l3"><a class="reference internal" href="inspection.html">Inspection</a></li> <li class="toctree-l3"><a class="reference internal" href="variable.html">Variable dictionary</a></li> <li class="toctree-l3"><a class="reference internal" href="spmd.html">SPMD</a></li> <li class="toctree-l3"><a class="reference internal" href="decorators.html">Decorators</a></li> <li class="toctree-l3"><a class="reference internal" href="profiling.html">Profiling</a></li> </ul> </details></li> <li class="toctree-l2"><a class="reference internal" href="../flax.serialization.html">flax.serialization package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.struct.html">flax.struct package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.traceback_util.html">flax.traceback_util package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.training.html">flax.training package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.traverse_util.html">flax.traverse_util package</a></li> </ul> </details></li> <li class="toctree-l1"><a class="reference external" href="https://flax.readthedocs.io/en/latest/index.html">Flax NNX</a></li> </ul> </div> </nav></div> </div> <div class="sidebar-primary-items__end sidebar-primary__section"> </div> <div id="rtd-footer-container"></div> </div> <main id="main-content" class="bd-main" role="main"> <div class="sbt-scroll-pixel-helper"></div> <div class="bd-content"> <div class="bd-article-container"> <div class="bd-header-article d-print-none"> <div class="header-article-items header-article__inner"> <div class="header-article-items__start"> <div class="header-article-item"><button class="sidebar-toggle primary-toggle btn btn-sm" title="Toggle primary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip"> <span class="fa-solid fa-bars"></span> </button></div> </div> <div class="header-article-items__end"> <div class="header-article-item"> <div class="article-header-buttons"> <a href="https://github.com/google/flax" target="_blank" class="btn btn-sm btn-source-repository-button" title="Source repository" data-bs-placement="bottom" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fab fa-github"></i> </span> </a> <div class="dropdown dropdown-download-buttons"> <button class="btn dropdown-toggle" type="button" data-bs-toggle="dropdown" aria-expanded="false" aria-label="Download this page"> <i class="fas fa-download"></i> </button> <ul class="dropdown-menu"> <li><a href="../../_sources/api_reference/flax.linen/layers.rst" target="_blank" class="btn btn-sm btn-download-source-button dropdown-item" title="Download source file" data-bs-placement="left" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-file"></i> </span> <span class="btn__text-container">.rst</span> </a> </li> <li> <button onclick="window.print()" class="btn btn-sm btn-download-pdf-button dropdown-item" title="Print to PDF" data-bs-placement="left" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-file-pdf"></i> </span> <span class="btn__text-container">.pdf</span> </button> </li> </ul> </div> <button onclick="toggleFullScreen()" class="btn btn-sm btn-fullscreen-button" title="Fullscreen mode" data-bs-placement="bottom" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-expand"></i> </span> </button> <script> document.write(` <button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button" title="light/dark" aria-label="light/dark" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light"></i> <i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark"></i> <i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto"></i> </button> `); </script> <script> document.write(` <button class="btn btn-sm pst-navbar-icon search-button search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="fa-solid fa-magnifying-glass fa-lg"></i> </button> `); </script> <button class="sidebar-toggle secondary-toggle btn btn-sm" title="Toggle secondary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip"> <span class="fa-solid fa-list"></span> </button> </div></div> </div> </div> </div> <div id="jb-print-docs-body" class="onlyprint"> <h1>Layers</h1> <!-- Table of contents --> <div id="print-main-content"> <div id="jb-print-toc"> <div> <h2> Contents </h2> </div> <nav aria-label="Page"> <ul class="visible nav section-nav flex-column"> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#linear-modules">Linear Modules</a><ul class="nav section-nav flex-column"> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dense"><code class="docutils literal notranslate"><span class="pre">Dense</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dense.features"><code class="docutils literal notranslate"><span class="pre">Dense.features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dense.use_bias"><code class="docutils literal notranslate"><span class="pre">Dense.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dense.dtype"><code class="docutils literal notranslate"><span class="pre">Dense.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dense.param_dtype"><code class="docutils literal notranslate"><span class="pre">Dense.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dense.precision"><code class="docutils literal notranslate"><span class="pre">Dense.precision</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dense.kernel_init"><code class="docutils literal notranslate"><span class="pre">Dense.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dense.bias_init"><code class="docutils literal notranslate"><span class="pre">Dense.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dense.__call__"><code class="docutils literal notranslate"><span class="pre">Dense.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.DenseGeneral"><code class="docutils literal notranslate"><span class="pre">DenseGeneral</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.DenseGeneral.features"><code class="docutils literal notranslate"><span class="pre">DenseGeneral.features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.DenseGeneral.axis"><code class="docutils literal notranslate"><span class="pre">DenseGeneral.axis</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.DenseGeneral.batch_dims"><code class="docutils literal notranslate"><span class="pre">DenseGeneral.batch_dims</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.DenseGeneral.use_bias"><code class="docutils literal notranslate"><span class="pre">DenseGeneral.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.DenseGeneral.dtype"><code class="docutils literal notranslate"><span class="pre">DenseGeneral.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.DenseGeneral.param_dtype"><code class="docutils literal notranslate"><span class="pre">DenseGeneral.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.DenseGeneral.kernel_init"><code class="docutils literal notranslate"><span class="pre">DenseGeneral.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.DenseGeneral.bias_init"><code class="docutils literal notranslate"><span class="pre">DenseGeneral.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.DenseGeneral.precision"><code class="docutils literal notranslate"><span class="pre">DenseGeneral.precision</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.DenseGeneral.__call__"><code class="docutils literal notranslate"><span class="pre">DenseGeneral.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv"><code class="docutils literal notranslate"><span class="pre">Conv</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.features"><code class="docutils literal notranslate"><span class="pre">Conv.features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.kernel_size"><code class="docutils literal notranslate"><span class="pre">Conv.kernel_size</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.strides"><code class="docutils literal notranslate"><span class="pre">Conv.strides</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.padding"><code class="docutils literal notranslate"><span class="pre">Conv.padding</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.input_dilation"><code class="docutils literal notranslate"><span class="pre">Conv.input_dilation</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.kernel_dilation"><code class="docutils literal notranslate"><span class="pre">Conv.kernel_dilation</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.feature_group_count"><code class="docutils literal notranslate"><span class="pre">Conv.feature_group_count</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.use_bias"><code class="docutils literal notranslate"><span class="pre">Conv.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.mask"><code class="docutils literal notranslate"><span class="pre">Conv.mask</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.dtype"><code class="docutils literal notranslate"><span class="pre">Conv.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.param_dtype"><code class="docutils literal notranslate"><span class="pre">Conv.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.precision"><code class="docutils literal notranslate"><span class="pre">Conv.precision</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.kernel_init"><code class="docutils literal notranslate"><span class="pre">Conv.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.bias_init"><code class="docutils literal notranslate"><span class="pre">Conv.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.__call__"><code class="docutils literal notranslate"><span class="pre">Conv.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose"><code class="docutils literal notranslate"><span class="pre">ConvTranspose</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.features"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.kernel_size"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.kernel_size</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.strides"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.strides</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.padding"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.padding</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.kernel_dilation"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.kernel_dilation</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.use_bias"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.mask"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.mask</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.dtype"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.param_dtype"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.precision"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.precision</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.kernel_init"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.bias_init"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.transpose_kernel"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.transpose_kernel</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.__call__"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal"><code class="docutils literal notranslate"><span class="pre">ConvLocal</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.features"><code class="docutils literal notranslate"><span class="pre">ConvLocal.features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.kernel_size"><code class="docutils literal notranslate"><span class="pre">ConvLocal.kernel_size</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.strides"><code class="docutils literal notranslate"><span class="pre">ConvLocal.strides</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.padding"><code class="docutils literal notranslate"><span class="pre">ConvLocal.padding</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.input_dilation"><code class="docutils literal notranslate"><span class="pre">ConvLocal.input_dilation</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.kernel_dilation"><code class="docutils literal notranslate"><span class="pre">ConvLocal.kernel_dilation</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.feature_group_count"><code class="docutils literal notranslate"><span class="pre">ConvLocal.feature_group_count</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.use_bias"><code class="docutils literal notranslate"><span class="pre">ConvLocal.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.mask"><code class="docutils literal notranslate"><span class="pre">ConvLocal.mask</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.dtype"><code class="docutils literal notranslate"><span class="pre">ConvLocal.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.param_dtype"><code class="docutils literal notranslate"><span class="pre">ConvLocal.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.precision"><code class="docutils literal notranslate"><span class="pre">ConvLocal.precision</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.kernel_init"><code class="docutils literal notranslate"><span class="pre">ConvLocal.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.bias_init"><code class="docutils literal notranslate"><span class="pre">ConvLocal.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.__call__"><code class="docutils literal notranslate"><span class="pre">ConvLocal.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Einsum"><code class="docutils literal notranslate"><span class="pre">Einsum</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Einsum.shape"><code class="docutils literal notranslate"><span class="pre">Einsum.shape</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Einsum.einsum_str"><code class="docutils literal notranslate"><span class="pre">Einsum.einsum_str</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Einsum.use_bias"><code class="docutils literal notranslate"><span class="pre">Einsum.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Einsum.dtype"><code class="docutils literal notranslate"><span class="pre">Einsum.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Einsum.param_dtype"><code class="docutils literal notranslate"><span class="pre">Einsum.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Einsum.precision"><code class="docutils literal notranslate"><span class="pre">Einsum.precision</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Einsum.kernel_init"><code class="docutils literal notranslate"><span class="pre">Einsum.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Einsum.bias_init"><code class="docutils literal notranslate"><span class="pre">Einsum.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Einsum.__call__"><code class="docutils literal notranslate"><span class="pre">Einsum.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Embed"><code class="docutils literal notranslate"><span class="pre">Embed</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Embed.num_embeddings"><code class="docutils literal notranslate"><span class="pre">Embed.num_embeddings</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Embed.features"><code class="docutils literal notranslate"><span class="pre">Embed.features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Embed.dtype"><code class="docutils literal notranslate"><span class="pre">Embed.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Embed.param_dtype"><code class="docutils literal notranslate"><span class="pre">Embed.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Embed.embedding_init"><code class="docutils literal notranslate"><span class="pre">Embed.embedding_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Embed.__call__"><code class="docutils literal notranslate"><span class="pre">Embed.__call__()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Embed.attend"><code class="docutils literal notranslate"><span class="pre">Embed.attend()</span></code></a></li> </ul> </li> </ul> </li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#pooling">Pooling</a><ul class="nav section-nav flex-column"> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.max_pool"><code class="docutils literal notranslate"><span class="pre">max_pool()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.avg_pool"><code class="docutils literal notranslate"><span class="pre">avg_pool()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.pool"><code class="docutils literal notranslate"><span class="pre">pool()</span></code></a></li> </ul> </li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#normalization">Normalization</a><ul class="nav section-nav flex-column"> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm"><code class="docutils literal notranslate"><span class="pre">BatchNorm</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.use_running_average"><code class="docutils literal notranslate"><span class="pre">BatchNorm.use_running_average</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.axis"><code class="docutils literal notranslate"><span class="pre">BatchNorm.axis</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.momentum"><code class="docutils literal notranslate"><span class="pre">BatchNorm.momentum</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.epsilon"><code class="docutils literal notranslate"><span class="pre">BatchNorm.epsilon</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.dtype"><code class="docutils literal notranslate"><span class="pre">BatchNorm.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.param_dtype"><code class="docutils literal notranslate"><span class="pre">BatchNorm.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.use_bias"><code class="docutils literal notranslate"><span class="pre">BatchNorm.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.use_scale"><code class="docutils literal notranslate"><span class="pre">BatchNorm.use_scale</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.bias_init"><code class="docutils literal notranslate"><span class="pre">BatchNorm.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.scale_init"><code class="docutils literal notranslate"><span class="pre">BatchNorm.scale_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.axis_name"><code class="docutils literal notranslate"><span class="pre">BatchNorm.axis_name</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.axis_index_groups"><code class="docutils literal notranslate"><span class="pre">BatchNorm.axis_index_groups</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.use_fast_variance"><code class="docutils literal notranslate"><span class="pre">BatchNorm.use_fast_variance</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.__call__"><code class="docutils literal notranslate"><span class="pre">BatchNorm.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm"><code class="docutils literal notranslate"><span class="pre">LayerNorm</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.epsilon"><code class="docutils literal notranslate"><span class="pre">LayerNorm.epsilon</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.dtype"><code class="docutils literal notranslate"><span class="pre">LayerNorm.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.param_dtype"><code class="docutils literal notranslate"><span class="pre">LayerNorm.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.use_bias"><code class="docutils literal notranslate"><span class="pre">LayerNorm.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.use_scale"><code class="docutils literal notranslate"><span class="pre">LayerNorm.use_scale</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.bias_init"><code class="docutils literal notranslate"><span class="pre">LayerNorm.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.scale_init"><code class="docutils literal notranslate"><span class="pre">LayerNorm.scale_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.reduction_axes"><code class="docutils literal notranslate"><span class="pre">LayerNorm.reduction_axes</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.feature_axes"><code class="docutils literal notranslate"><span class="pre">LayerNorm.feature_axes</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.axis_name"><code class="docutils literal notranslate"><span class="pre">LayerNorm.axis_name</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.axis_index_groups"><code class="docutils literal notranslate"><span class="pre">LayerNorm.axis_index_groups</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.use_fast_variance"><code class="docutils literal notranslate"><span class="pre">LayerNorm.use_fast_variance</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.__call__"><code class="docutils literal notranslate"><span class="pre">LayerNorm.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm"><code class="docutils literal notranslate"><span class="pre">GroupNorm</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.num_groups"><code class="docutils literal notranslate"><span class="pre">GroupNorm.num_groups</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.group_size"><code class="docutils literal notranslate"><span class="pre">GroupNorm.group_size</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.epsilon"><code class="docutils literal notranslate"><span class="pre">GroupNorm.epsilon</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.dtype"><code class="docutils literal notranslate"><span class="pre">GroupNorm.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.param_dtype"><code class="docutils literal notranslate"><span class="pre">GroupNorm.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.use_bias"><code class="docutils literal notranslate"><span class="pre">GroupNorm.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.use_scale"><code class="docutils literal notranslate"><span class="pre">GroupNorm.use_scale</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.bias_init"><code class="docutils literal notranslate"><span class="pre">GroupNorm.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.scale_init"><code class="docutils literal notranslate"><span class="pre">GroupNorm.scale_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.reduction_axes"><code class="docutils literal notranslate"><span class="pre">GroupNorm.reduction_axes</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.axis_name"><code class="docutils literal notranslate"><span class="pre">GroupNorm.axis_name</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.axis_index_groups"><code class="docutils literal notranslate"><span class="pre">GroupNorm.axis_index_groups</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.use_fast_variance"><code class="docutils literal notranslate"><span class="pre">GroupNorm.use_fast_variance</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.__call__"><code class="docutils literal notranslate"><span class="pre">GroupNorm.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm"><code class="docutils literal notranslate"><span class="pre">RMSNorm</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm.epsilon"><code class="docutils literal notranslate"><span class="pre">RMSNorm.epsilon</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm.dtype"><code class="docutils literal notranslate"><span class="pre">RMSNorm.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm.param_dtype"><code class="docutils literal notranslate"><span class="pre">RMSNorm.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm.use_scale"><code class="docutils literal notranslate"><span class="pre">RMSNorm.use_scale</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm.scale_init"><code class="docutils literal notranslate"><span class="pre">RMSNorm.scale_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm.reduction_axes"><code class="docutils literal notranslate"><span class="pre">RMSNorm.reduction_axes</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm.feature_axes"><code class="docutils literal notranslate"><span class="pre">RMSNorm.feature_axes</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm.axis_name"><code class="docutils literal notranslate"><span class="pre">RMSNorm.axis_name</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm.axis_index_groups"><code class="docutils literal notranslate"><span class="pre">RMSNorm.axis_index_groups</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm.use_fast_variance"><code class="docutils literal notranslate"><span class="pre">RMSNorm.use_fast_variance</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm.__call__"><code class="docutils literal notranslate"><span class="pre">RMSNorm.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm"><code class="docutils literal notranslate"><span class="pre">InstanceNorm</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.epsilon"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.epsilon</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.dtype"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.param_dtype"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.use_bias"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.use_scale"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.use_scale</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.bias_init"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.scale_init"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.scale_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.feature_axes"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.feature_axes</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.axis_name"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.axis_name</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.axis_index_groups"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.axis_index_groups</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.use_fast_variance"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.use_fast_variance</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.__call__"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SpectralNorm"><code class="docutils literal notranslate"><span class="pre">SpectralNorm</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SpectralNorm.layer_instance"><code class="docutils literal notranslate"><span class="pre">SpectralNorm.layer_instance</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SpectralNorm.n_steps"><code class="docutils literal notranslate"><span class="pre">SpectralNorm.n_steps</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SpectralNorm.epsilon"><code class="docutils literal notranslate"><span class="pre">SpectralNorm.epsilon</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SpectralNorm.dtype"><code class="docutils literal notranslate"><span class="pre">SpectralNorm.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SpectralNorm.param_dtype"><code class="docutils literal notranslate"><span class="pre">SpectralNorm.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SpectralNorm.error_on_non_matrix"><code class="docutils literal notranslate"><span class="pre">SpectralNorm.error_on_non_matrix</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SpectralNorm.collection_name"><code class="docutils literal notranslate"><span class="pre">SpectralNorm.collection_name</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SpectralNorm.__call__"><code class="docutils literal notranslate"><span class="pre">SpectralNorm.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.WeightNorm"><code class="docutils literal notranslate"><span class="pre">WeightNorm</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.WeightNorm.layer_instance"><code class="docutils literal notranslate"><span class="pre">WeightNorm.layer_instance</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.WeightNorm.epsilon"><code class="docutils literal notranslate"><span class="pre">WeightNorm.epsilon</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.WeightNorm.dtype"><code class="docutils literal notranslate"><span class="pre">WeightNorm.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.WeightNorm.param_dtype"><code class="docutils literal notranslate"><span class="pre">WeightNorm.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.WeightNorm.use_scale"><code class="docutils literal notranslate"><span class="pre">WeightNorm.use_scale</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.WeightNorm.scale_init"><code class="docutils literal notranslate"><span class="pre">WeightNorm.scale_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.WeightNorm.feature_axes"><code class="docutils literal notranslate"><span class="pre">WeightNorm.feature_axes</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.WeightNorm.variable_filter"><code class="docutils literal notranslate"><span class="pre">WeightNorm.variable_filter</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.WeightNorm.__call__"><code class="docutils literal notranslate"><span class="pre">WeightNorm.__call__()</span></code></a></li> </ul> </li> </ul> </li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#combinators">Combinators</a><ul class="nav section-nav flex-column"> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Sequential"><code class="docutils literal notranslate"><span class="pre">Sequential</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Sequential.layers"><code class="docutils literal notranslate"><span class="pre">Sequential.layers</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Sequential.__call__"><code class="docutils literal notranslate"><span class="pre">Sequential.__call__()</span></code></a></li> </ul> </li> </ul> </li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#stochastic">Stochastic</a><ul class="nav section-nav flex-column"> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dropout"><code class="docutils literal notranslate"><span class="pre">Dropout</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dropout.rate"><code class="docutils literal notranslate"><span class="pre">Dropout.rate</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dropout.broadcast_dims"><code class="docutils literal notranslate"><span class="pre">Dropout.broadcast_dims</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dropout.deterministic"><code class="docutils literal notranslate"><span class="pre">Dropout.deterministic</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dropout.rng_collection"><code class="docutils literal notranslate"><span class="pre">Dropout.rng_collection</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dropout.__call__"><code class="docutils literal notranslate"><span class="pre">Dropout.__call__()</span></code></a></li> </ul> </li> </ul> </li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#attention">Attention</a><ul class="nav section-nav flex-column"> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.num_heads"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.num_heads</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.dtype"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.param_dtype"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.qkv_features"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.qkv_features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.out_features"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.out_features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.broadcast_dropout"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.broadcast_dropout</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.dropout_rate"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.dropout_rate</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.deterministic"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.deterministic</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.precision"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.precision</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.kernel_init"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.out_kernel_init"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.out_kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.bias_init"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.out_bias_init"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.out_bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.use_bias"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.attention_fn"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.attention_fn</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.decode"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.decode</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.normalize_qk"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.normalize_qk</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.qk_attn_weights_einsum_cls"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.qk_attn_weights_einsum_cls</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.attn_weights_value_einsum_cls"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.attn_weights_value_einsum_cls</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.__call__"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.num_heads"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.num_heads</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.dtype"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.param_dtype"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.qkv_features"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.qkv_features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.out_features"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.out_features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.broadcast_dropout"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.broadcast_dropout</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.dropout_rate"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.dropout_rate</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.deterministic"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.deterministic</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.precision"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.precision</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.kernel_init"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.bias_init"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.use_bias"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.attention_fn"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.attention_fn</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.decode"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.decode</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.normalize_qk"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.normalize_qk</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.__call__"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SelfAttention"><code class="docutils literal notranslate"><span class="pre">SelfAttention</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SelfAttention.__call__"><code class="docutils literal notranslate"><span class="pre">SelfAttention.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.dot_product_attention_weights"><code class="docutils literal notranslate"><span class="pre">dot_product_attention_weights()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.dot_product_attention"><code class="docutils literal notranslate"><span class="pre">dot_product_attention()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.make_attention_mask"><code class="docutils literal notranslate"><span class="pre">make_attention_mask()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.make_causal_mask"><code class="docutils literal notranslate"><span class="pre">make_causal_mask()</span></code></a></li> </ul> </li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#recurrent">Recurrent</a><ul class="nav section-nav flex-column"> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNNCellBase"><code class="docutils literal notranslate"><span class="pre">RNNCellBase</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNNCellBase.__call__"><code class="docutils literal notranslate"><span class="pre">RNNCellBase.__call__()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNNCellBase.initialize_carry"><code class="docutils literal notranslate"><span class="pre">RNNCellBase.initialize_carry()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LSTMCell"><code class="docutils literal notranslate"><span class="pre">LSTMCell</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LSTMCell.features"><code class="docutils literal notranslate"><span class="pre">LSTMCell.features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LSTMCell.gate_fn"><code class="docutils literal notranslate"><span class="pre">LSTMCell.gate_fn</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LSTMCell.activation_fn"><code class="docutils literal notranslate"><span class="pre">LSTMCell.activation_fn</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LSTMCell.kernel_init"><code class="docutils literal notranslate"><span class="pre">LSTMCell.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LSTMCell.recurrent_kernel_init"><code class="docutils literal notranslate"><span class="pre">LSTMCell.recurrent_kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LSTMCell.bias_init"><code class="docutils literal notranslate"><span class="pre">LSTMCell.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LSTMCell.dtype"><code class="docutils literal notranslate"><span class="pre">LSTMCell.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LSTMCell.param_dtype"><code class="docutils literal notranslate"><span class="pre">LSTMCell.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LSTMCell.__call__"><code class="docutils literal notranslate"><span class="pre">LSTMCell.__call__()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LSTMCell.initialize_carry"><code class="docutils literal notranslate"><span class="pre">LSTMCell.initialize_carry()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.OptimizedLSTMCell"><code class="docutils literal notranslate"><span class="pre">OptimizedLSTMCell</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.OptimizedLSTMCell.gate_fn"><code class="docutils literal notranslate"><span class="pre">OptimizedLSTMCell.gate_fn</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.OptimizedLSTMCell.activation_fn"><code class="docutils literal notranslate"><span class="pre">OptimizedLSTMCell.activation_fn</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.OptimizedLSTMCell.kernel_init"><code class="docutils literal notranslate"><span class="pre">OptimizedLSTMCell.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.OptimizedLSTMCell.recurrent_kernel_init"><code class="docutils literal notranslate"><span class="pre">OptimizedLSTMCell.recurrent_kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.OptimizedLSTMCell.bias_init"><code class="docutils literal notranslate"><span class="pre">OptimizedLSTMCell.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.OptimizedLSTMCell.dtype"><code class="docutils literal notranslate"><span class="pre">OptimizedLSTMCell.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.OptimizedLSTMCell.param_dtype"><code class="docutils literal notranslate"><span class="pre">OptimizedLSTMCell.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.OptimizedLSTMCell.__call__"><code class="docutils literal notranslate"><span class="pre">OptimizedLSTMCell.__call__()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.OptimizedLSTMCell.initialize_carry"><code class="docutils literal notranslate"><span class="pre">OptimizedLSTMCell.initialize_carry()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLSTMCell"><code class="docutils literal notranslate"><span class="pre">ConvLSTMCell</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLSTMCell.features"><code class="docutils literal notranslate"><span class="pre">ConvLSTMCell.features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLSTMCell.kernel_size"><code class="docutils literal notranslate"><span class="pre">ConvLSTMCell.kernel_size</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLSTMCell.strides"><code class="docutils literal notranslate"><span class="pre">ConvLSTMCell.strides</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLSTMCell.padding"><code class="docutils literal notranslate"><span class="pre">ConvLSTMCell.padding</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLSTMCell.bias"><code class="docutils literal notranslate"><span class="pre">ConvLSTMCell.bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLSTMCell.dtype"><code class="docutils literal notranslate"><span class="pre">ConvLSTMCell.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLSTMCell.param_dtype"><code class="docutils literal notranslate"><span class="pre">ConvLSTMCell.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLSTMCell.__call__"><code class="docutils literal notranslate"><span class="pre">ConvLSTMCell.__call__()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLSTMCell.initialize_carry"><code class="docutils literal notranslate"><span class="pre">ConvLSTMCell.initialize_carry()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SimpleCell"><code class="docutils literal notranslate"><span class="pre">SimpleCell</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SimpleCell.features"><code class="docutils literal notranslate"><span class="pre">SimpleCell.features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SimpleCell.activation_fn"><code class="docutils literal notranslate"><span class="pre">SimpleCell.activation_fn</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SimpleCell.kernel_init"><code class="docutils literal notranslate"><span class="pre">SimpleCell.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SimpleCell.recurrent_kernel_init"><code class="docutils literal notranslate"><span class="pre">SimpleCell.recurrent_kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SimpleCell.bias_init"><code class="docutils literal notranslate"><span class="pre">SimpleCell.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SimpleCell.dtype"><code class="docutils literal notranslate"><span class="pre">SimpleCell.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SimpleCell.param_dtype"><code class="docutils literal notranslate"><span class="pre">SimpleCell.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SimpleCell.residual"><code class="docutils literal notranslate"><span class="pre">SimpleCell.residual</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SimpleCell.__call__"><code class="docutils literal notranslate"><span class="pre">SimpleCell.__call__()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SimpleCell.initialize_carry"><code class="docutils literal notranslate"><span class="pre">SimpleCell.initialize_carry()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GRUCell"><code class="docutils literal notranslate"><span class="pre">GRUCell</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GRUCell.features"><code class="docutils literal notranslate"><span class="pre">GRUCell.features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GRUCell.gate_fn"><code class="docutils literal notranslate"><span class="pre">GRUCell.gate_fn</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GRUCell.activation_fn"><code class="docutils literal notranslate"><span class="pre">GRUCell.activation_fn</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GRUCell.kernel_init"><code class="docutils literal notranslate"><span class="pre">GRUCell.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GRUCell.recurrent_kernel_init"><code class="docutils literal notranslate"><span class="pre">GRUCell.recurrent_kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GRUCell.bias_init"><code class="docutils literal notranslate"><span class="pre">GRUCell.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GRUCell.dtype"><code class="docutils literal notranslate"><span class="pre">GRUCell.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GRUCell.param_dtype"><code class="docutils literal notranslate"><span class="pre">GRUCell.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GRUCell.__call__"><code class="docutils literal notranslate"><span class="pre">GRUCell.__call__()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GRUCell.initialize_carry"><code class="docutils literal notranslate"><span class="pre">GRUCell.initialize_carry()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell"><code class="docutils literal notranslate"><span class="pre">MGUCell</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.features"><code class="docutils literal notranslate"><span class="pre">MGUCell.features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.gate_fn"><code class="docutils literal notranslate"><span class="pre">MGUCell.gate_fn</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.activation_fn"><code class="docutils literal notranslate"><span class="pre">MGUCell.activation_fn</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.kernel_init"><code class="docutils literal notranslate"><span class="pre">MGUCell.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.recurrent_kernel_init"><code class="docutils literal notranslate"><span class="pre">MGUCell.recurrent_kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.forget_bias_init"><code class="docutils literal notranslate"><span class="pre">MGUCell.forget_bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.activation_bias_init"><code class="docutils literal notranslate"><span class="pre">MGUCell.activation_bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.dtype"><code class="docutils literal notranslate"><span class="pre">MGUCell.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.param_dtype"><code class="docutils literal notranslate"><span class="pre">MGUCell.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.reset_gate"><code class="docutils literal notranslate"><span class="pre">MGUCell.reset_gate</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.__call__"><code class="docutils literal notranslate"><span class="pre">MGUCell.__call__()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.initialize_carry"><code class="docutils literal notranslate"><span class="pre">MGUCell.initialize_carry()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN"><code class="docutils literal notranslate"><span class="pre">RNN</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN.cell"><code class="docutils literal notranslate"><span class="pre">RNN.cell</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN.time_major"><code class="docutils literal notranslate"><span class="pre">RNN.time_major</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN.return_carry"><code class="docutils literal notranslate"><span class="pre">RNN.return_carry</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN.reverse"><code class="docutils literal notranslate"><span class="pre">RNN.reverse</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN.keep_order"><code class="docutils literal notranslate"><span class="pre">RNN.keep_order</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN.unroll"><code class="docutils literal notranslate"><span class="pre">RNN.unroll</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN.variable_axes"><code class="docutils literal notranslate"><span class="pre">RNN.variable_axes</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN.variable_broadcast"><code class="docutils literal notranslate"><span class="pre">RNN.variable_broadcast</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN.variable_carry"><code class="docutils literal notranslate"><span class="pre">RNN.variable_carry</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN.split_rngs"><code class="docutils literal notranslate"><span class="pre">RNN.split_rngs</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN.__call__"><code class="docutils literal notranslate"><span class="pre">RNN.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Bidirectional"><code class="docutils literal notranslate"><span class="pre">Bidirectional</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Bidirectional.__call__"><code class="docutils literal notranslate"><span class="pre">Bidirectional.__call__()</span></code></a></li> </ul> </li> </ul> </li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#batchapply">BatchApply</a><ul class="nav section-nav flex-column"> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchApply"><code class="docutils literal notranslate"><span class="pre">BatchApply</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchApply.__call__"><code class="docutils literal notranslate"><span class="pre">BatchApply.__call__()</span></code></a></li> </ul> </li> </ul> </li> </ul> </nav> </div> </div> </div> <div id="searchbox"></div> <article class="bd-article"> <div class="section" id="layers"> <h1>Layers<a class="headerlink" href="#layers" title="Permalink to this heading">#</a></h1> <div class="section" id="linear-modules"> <h2>Linear Modules<a class="headerlink" href="#linear-modules" title="Permalink to this heading">#</a></h2> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.Dense"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">Dense</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_bias=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">precision=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kernel_init=<function</span> <span class="pre">variance_scaling.<locals>.init></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bias_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dot_general=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dot_general_cls=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/linear.html#Dense"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Dense" title="Permalink to this definition">#</a></dt> <dd><p>A linear transformation applied over the last dimension of the input.</p> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">params</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))</span> <span class="gp">>>> </span><span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">params</span><span class="p">)</span> <span class="go">{'params': {'bias': (4,), 'kernel': (3, 4)}}</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Dense.features"> <span class="sig-name descname"><span class="pre">features</span></span><a class="headerlink" href="#flax.linen.Dense.features" title="Permalink to this definition">#</a></dt> <dd><p>the number of output features.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Dense.use_bias"> <span class="sig-name descname"><span class="pre">use_bias</span></span><a class="headerlink" href="#flax.linen.Dense.use_bias" title="Permalink to this definition">#</a></dt> <dd><p>whether to add a bias to the output (default: True).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Dense.dtype"> <span class="sig-name descname"><span class="pre">dtype</span></span><a class="headerlink" href="#flax.linen.Dense.dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype of the computation (default: infer from input and params).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Dense.param_dtype"> <span class="sig-name descname"><span class="pre">param_dtype</span></span><a class="headerlink" href="#flax.linen.Dense.param_dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype passed to parameter initializers (default: float32).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Dense.precision"> <span class="sig-name descname"><span class="pre">precision</span></span><a class="headerlink" href="#flax.linen.Dense.precision" title="Permalink to this definition">#</a></dt> <dd><p>numerical precision of the computation see <code class="docutils literal notranslate"><span class="pre">jax.lax.Precision</span></code> for details.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Dense.kernel_init"> <span class="sig-name descname"><span class="pre">kernel_init</span></span><a class="headerlink" href="#flax.linen.Dense.kernel_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer function for the weight matrix.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Dense.bias_init"> <span class="sig-name descname"><span class="pre">bias_init</span></span><a class="headerlink" href="#flax.linen.Dense.bias_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer function for the bias.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Dense.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">inputs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/linear.html#Dense.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Dense.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Applies a linear transformation to the inputs along the last dimension.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><p><strong>inputs</strong> – The nd-array to be transformed.</p> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>The transformed input.</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> </tbody> </table> </div> </dd></dl> </div> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.DenseGeneral"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">DenseGeneral</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">axis=-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_dims=()</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_bias=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kernel_init=<function</span> <span class="pre">variance_scaling.<locals>.init></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bias_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">precision=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dot_general=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dot_general_cls=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/linear.html#DenseGeneral"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.DenseGeneral" title="Permalink to this definition">#</a></dt> <dd><p>A linear transformation with flexible axes.</p> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="c1"># equivalent to `nn.Dense(features=4)`</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">DenseGeneral</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span> <span class="gp">>>> </span><span class="c1"># output features (4, 5)</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">DenseGeneral</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">params</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))</span> <span class="gp">>>> </span><span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">params</span><span class="p">)</span> <span class="go">{'params': {'bias': (4, 5), 'kernel': (3, 4, 5)}}</span> <span class="gp">>>> </span><span class="c1"># apply transformation on the the second and last axes</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">DenseGeneral</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">params</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">7</span><span class="p">)))</span> <span class="gp">>>> </span><span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">params</span><span class="p">)</span> <span class="go">{'params': {'bias': (4, 5), 'kernel': (3, 7, 4, 5)}}</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.DenseGeneral.features"> <span class="sig-name descname"><span class="pre">features</span></span><a class="headerlink" href="#flax.linen.DenseGeneral.features" title="Permalink to this definition">#</a></dt> <dd><p>int or tuple with number of output features.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int | collections.abc.Sequence[int]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.DenseGeneral.axis"> <span class="sig-name descname"><span class="pre">axis</span></span><a class="headerlink" href="#flax.linen.DenseGeneral.axis" title="Permalink to this definition">#</a></dt> <dd><p>int or tuple with axes to apply the transformation on. For instance, (-2, -1) will apply the transformation to the last two axes.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int | collections.abc.Sequence[int]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.DenseGeneral.batch_dims"> <span class="sig-name descname"><span class="pre">batch_dims</span></span><a class="headerlink" href="#flax.linen.DenseGeneral.batch_dims" title="Permalink to this definition">#</a></dt> <dd><p>tuple with batch axes.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Sequence[int]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.DenseGeneral.use_bias"> <span class="sig-name descname"><span class="pre">use_bias</span></span><a class="headerlink" href="#flax.linen.DenseGeneral.use_bias" title="Permalink to this definition">#</a></dt> <dd><p>whether to add a bias to the output (default: True).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.DenseGeneral.dtype"> <span class="sig-name descname"><span class="pre">dtype</span></span><a class="headerlink" href="#flax.linen.DenseGeneral.dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype of the computation (default: infer from input and params).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.DenseGeneral.param_dtype"> <span class="sig-name descname"><span class="pre">param_dtype</span></span><a class="headerlink" href="#flax.linen.DenseGeneral.param_dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype passed to parameter initializers (default: float32).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.DenseGeneral.kernel_init"> <span class="sig-name descname"><span class="pre">kernel_init</span></span><a class="headerlink" href="#flax.linen.DenseGeneral.kernel_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer function for the weight matrix.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.DenseGeneral.bias_init"> <span class="sig-name descname"><span class="pre">bias_init</span></span><a class="headerlink" href="#flax.linen.DenseGeneral.bias_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer function for the bias.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.DenseGeneral.precision"> <span class="sig-name descname"><span class="pre">precision</span></span><a class="headerlink" href="#flax.linen.DenseGeneral.precision" title="Permalink to this definition">#</a></dt> <dd><p>numerical precision of the computation see <code class="docutils literal notranslate"><span class="pre">jax.lax.Precision</span></code> for details.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.DenseGeneral.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">inputs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/linear.html#DenseGeneral.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.DenseGeneral.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Applies a linear transformation to the inputs along multiple dimensions.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><p><strong>inputs</strong> – The nd-array to be transformed.</p> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>The transformed input.</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> </tbody> </table> </div> </dd></dl> </div> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.Conv"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">Conv</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kernel_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">strides=1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">padding='SAME'</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">input_dilation=1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kernel_dilation=1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">feature_group_count=1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_bias=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">precision=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kernel_init=<function</span> <span class="pre">variance_scaling.<locals>.init></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bias_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">conv_general_dilated=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">conv_general_dilated_cls=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/linear.html#Conv"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Conv" title="Permalink to this definition">#</a></dt> <dd><p>Convolution Module wrapping <code class="docutils literal notranslate"><span class="pre">lax.conv_general_dilated</span></code>.</p> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="c1"># valid padding</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,),</span> <span class="n">padding</span><span class="o">=</span><span class="s1">'VALID'</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">out</span><span class="p">,</span> <span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init_with_output</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))</span> <span class="gp">>>> </span><span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">variables</span><span class="p">)</span> <span class="go">{'params': {'bias': (4,), 'kernel': (3, 3, 4)}}</span> <span class="gp">>>> </span><span class="n">out</span><span class="o">.</span><span class="n">shape</span> <span class="go">(1, 6, 4)</span> <span class="gp">>>> </span><span class="c1"># circular padding with stride 2</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">strides</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s1">'CIRCULAR'</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">out</span><span class="p">,</span> <span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init_with_output</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))</span> <span class="gp">>>> </span><span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">variables</span><span class="p">)</span> <span class="go">{'params': {'bias': (4,), 'kernel': (3, 3, 3, 4)}}</span> <span class="gp">>>> </span><span class="n">out</span><span class="o">.</span><span class="n">shape</span> <span class="go">(1, 4, 4)</span> <span class="gp">>>> </span><span class="c1"># apply lower triangle mask</span> <span class="gp">>>> </span><span class="n">mask</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">tril</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">)))</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,),</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s1">'VALID'</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Conv.features"> <span class="sig-name descname"><span class="pre">features</span></span><a class="headerlink" href="#flax.linen.Conv.features" title="Permalink to this definition">#</a></dt> <dd><p>number of convolution filters.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Conv.kernel_size"> <span class="sig-name descname"><span class="pre">kernel_size</span></span><a class="headerlink" href="#flax.linen.Conv.kernel_size" title="Permalink to this definition">#</a></dt> <dd><p>shape of the convolutional kernel. An integer will be interpreted as a tuple of the single integer.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int | collections.abc.Sequence[int]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Conv.strides"> <span class="sig-name descname"><span class="pre">strides</span></span><a class="headerlink" href="#flax.linen.Conv.strides" title="Permalink to this definition">#</a></dt> <dd><p>an integer or a sequence of <cite>n</cite> integers, representing the inter-window strides (default: 1).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>None | int | collections.abc.Sequence[int]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Conv.padding"> <span class="sig-name descname"><span class="pre">padding</span></span><a class="headerlink" href="#flax.linen.Conv.padding" title="Permalink to this definition">#</a></dt> <dd><p>either the string <code class="docutils literal notranslate"><span class="pre">'SAME'</span></code>, the string <code class="docutils literal notranslate"><span class="pre">'VALID'</span></code>, the string <code class="docutils literal notranslate"><span class="pre">'CIRCULAR'</span></code> (periodic boundary conditions), or a sequence of <code class="docutils literal notranslate"><span class="pre">n</span></code> <code class="docutils literal notranslate"><span class="pre">(low,</span> <span class="pre">high)</span></code> integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims and assign a single int in a sequence causes the same padding to be used on both sides. <code class="docutils literal notranslate"><span class="pre">'CAUSAL'</span></code> padding for a 1D convolution will left-pad the convolution axis, resulting in same-sized output.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, int, collections.abc.Sequence[Union[int, tuple[int, int]]]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Conv.input_dilation"> <span class="sig-name descname"><span class="pre">input_dilation</span></span><a class="headerlink" href="#flax.linen.Conv.input_dilation" title="Permalink to this definition">#</a></dt> <dd><p>an integer or a sequence of <code class="docutils literal notranslate"><span class="pre">n</span></code> integers, giving the dilation factor to apply in each spatial dimension of <code class="docutils literal notranslate"><span class="pre">inputs</span></code> (default: 1). Convolution with input dilation <code class="docutils literal notranslate"><span class="pre">d</span></code> is equivalent to transposed convolution with stride <code class="docutils literal notranslate"><span class="pre">d</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>None | int | collections.abc.Sequence[int]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Conv.kernel_dilation"> <span class="sig-name descname"><span class="pre">kernel_dilation</span></span><a class="headerlink" href="#flax.linen.Conv.kernel_dilation" title="Permalink to this definition">#</a></dt> <dd><p>an integer or a sequence of <code class="docutils literal notranslate"><span class="pre">n</span></code> integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: 1). Convolution with kernel dilation is also known as ‘atrous convolution’.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>None | int | collections.abc.Sequence[int]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Conv.feature_group_count"> <span class="sig-name descname"><span class="pre">feature_group_count</span></span><a class="headerlink" href="#flax.linen.Conv.feature_group_count" title="Permalink to this definition">#</a></dt> <dd><p>integer, default 1. If specified divides the input features into groups.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Conv.use_bias"> <span class="sig-name descname"><span class="pre">use_bias</span></span><a class="headerlink" href="#flax.linen.Conv.use_bias" title="Permalink to this definition">#</a></dt> <dd><p>whether to add a bias to the output (default: True).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Conv.mask"> <span class="sig-name descname"><span class="pre">mask</span></span><a class="headerlink" href="#flax.linen.Conv.mask" title="Permalink to this definition">#</a></dt> <dd><p>Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[jax.Array, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Conv.dtype"> <span class="sig-name descname"><span class="pre">dtype</span></span><a class="headerlink" href="#flax.linen.Conv.dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype of the computation (default: infer from input and params).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Conv.param_dtype"> <span class="sig-name descname"><span class="pre">param_dtype</span></span><a class="headerlink" href="#flax.linen.Conv.param_dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype passed to parameter initializers (default: float32).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Conv.precision"> <span class="sig-name descname"><span class="pre">precision</span></span><a class="headerlink" href="#flax.linen.Conv.precision" title="Permalink to this definition">#</a></dt> <dd><p>numerical precision of the computation see <a href="#id1"><span class="problematic" id="id2">``</span></a>jax.lax.Precision` for details.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Conv.kernel_init"> <span class="sig-name descname"><span class="pre">kernel_init</span></span><a class="headerlink" href="#flax.linen.Conv.kernel_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer for the convolutional kernel.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Conv.bias_init"> <span class="sig-name descname"><span class="pre">bias_init</span></span><a class="headerlink" href="#flax.linen.Conv.bias_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer for the bias.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Conv.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">inputs</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#flax.linen.Conv.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Applies a (potentially unshared) convolution to the inputs.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><p><strong>inputs</strong> – input data with dimensions <code class="docutils literal notranslate"><span class="pre">(*batch_dims,</span> <span class="pre">spatial_dims...,</span> <span class="pre">features)</span></code>. This is the channels-last convention, i.e. NHWC for a 2d convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by <code class="docutils literal notranslate"><span class="pre">lax.conv_general_dilated</span></code>, which puts the spatial dimensions last. Note: If the input has more than 1 batch dimension, all batch dimensions are flattened into a single dimension for the convolution and restored before returning. In some cases directly vmap’ing the layer may yield better performance than this default flattening approach. If the input lacks a batch dimension it will be added for the convolution and removed n return, an allowance made to enable writing single-example code.</p> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>The convolved data.</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> </tbody> </table> </div> </dd></dl> </div> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.ConvTranspose"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">ConvTranspose</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kernel_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">strides=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">padding='SAME'</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kernel_dilation=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_bias=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">precision=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kernel_init=<function</span> <span class="pre">variance_scaling.<locals>.init></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bias_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">transpose_kernel=False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/linear.html#ConvTranspose"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.ConvTranspose" title="Permalink to this definition">#</a></dt> <dd><p>Convolution Module wrapping <code class="docutils literal notranslate"><span class="pre">lax.conv_transpose</span></code>.</p> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="c1"># valid padding</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ConvTranspose</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,),</span> <span class="n">padding</span><span class="o">=</span><span class="s1">'VALID'</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">out</span><span class="p">,</span> <span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init_with_output</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))</span> <span class="gp">>>> </span><span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">variables</span><span class="p">)</span> <span class="go">{'params': {'bias': (4,), 'kernel': (3, 3, 4)}}</span> <span class="gp">>>> </span><span class="n">out</span><span class="o">.</span><span class="n">shape</span> <span class="go">(1, 10, 4)</span> <span class="gp">>>> </span><span class="c1"># circular padding with stride 2</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ConvTranspose</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">6</span><span class="p">,</span> <span class="mi">6</span><span class="p">),</span> <span class="n">strides</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s1">'CIRCULAR'</span><span class="p">,</span> <span class="n">transpose_kernel</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">out</span><span class="p">,</span> <span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init_with_output</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">15</span><span class="p">,</span> <span class="mi">15</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))</span> <span class="gp">>>> </span><span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">variables</span><span class="p">)</span> <span class="go">{'params': {'bias': (4,), 'kernel': (6, 6, 4, 3)}}</span> <span class="gp">>>> </span><span class="n">out</span><span class="o">.</span><span class="n">shape</span> <span class="go">(1, 30, 30, 4)</span> <span class="gp">>>> </span><span class="c1"># apply lower triangle mask</span> <span class="gp">>>> </span><span class="n">mask</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">tril</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">)))</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ConvTranspose</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,),</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s1">'VALID'</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvTranspose.features"> <span class="sig-name descname"><span class="pre">features</span></span><a class="headerlink" href="#flax.linen.ConvTranspose.features" title="Permalink to this definition">#</a></dt> <dd><p>number of convolution filters.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvTranspose.kernel_size"> <span class="sig-name descname"><span class="pre">kernel_size</span></span><a class="headerlink" href="#flax.linen.ConvTranspose.kernel_size" title="Permalink to this definition">#</a></dt> <dd><p>shape of the convolutional kernel. For 1D convolution, the kernel size can be passed as an integer, which will be interpreted as a tuple of the single integer. For all other cases, it must be a sequence of integers.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int | collections.abc.Sequence[int]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvTranspose.strides"> <span class="sig-name descname"><span class="pre">strides</span></span><a class="headerlink" href="#flax.linen.ConvTranspose.strides" title="Permalink to this definition">#</a></dt> <dd><p>an integer or a sequence of <cite>n</cite> integers, representing the inter-window strides.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Sequence[int] | None</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvTranspose.padding"> <span class="sig-name descname"><span class="pre">padding</span></span><a class="headerlink" href="#flax.linen.ConvTranspose.padding" title="Permalink to this definition">#</a></dt> <dd><p>either the string <cite>‘SAME’</cite>, the string <cite>‘VALID’</cite>, the string <cite>‘CIRCULAR’</cite> (periodic boundary conditions), or a sequence of <cite>n</cite> <cite>(low, high)</cite> integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims and assign a single int in a sequence causes the same padding to be used on both sides.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, int, collections.abc.Sequence[Union[int, tuple[int, int]]]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvTranspose.kernel_dilation"> <span class="sig-name descname"><span class="pre">kernel_dilation</span></span><a class="headerlink" href="#flax.linen.ConvTranspose.kernel_dilation" title="Permalink to this definition">#</a></dt> <dd><p><code class="docutils literal notranslate"><span class="pre">None</span></code>, or an integer or a sequence of <code class="docutils literal notranslate"><span class="pre">n</span></code> integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel. Convolution with kernel dilation is also known as ‘atrous convolution’.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Sequence[int] | None</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvTranspose.use_bias"> <span class="sig-name descname"><span class="pre">use_bias</span></span><a class="headerlink" href="#flax.linen.ConvTranspose.use_bias" title="Permalink to this definition">#</a></dt> <dd><p>whether to add a bias to the output (default: True).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvTranspose.mask"> <span class="sig-name descname"><span class="pre">mask</span></span><a class="headerlink" href="#flax.linen.ConvTranspose.mask" title="Permalink to this definition">#</a></dt> <dd><p>Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[jax.Array, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvTranspose.dtype"> <span class="sig-name descname"><span class="pre">dtype</span></span><a class="headerlink" href="#flax.linen.ConvTranspose.dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype of the computation (default: infer from input and params).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvTranspose.param_dtype"> <span class="sig-name descname"><span class="pre">param_dtype</span></span><a class="headerlink" href="#flax.linen.ConvTranspose.param_dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype passed to parameter initializers (default: float32).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvTranspose.precision"> <span class="sig-name descname"><span class="pre">precision</span></span><a class="headerlink" href="#flax.linen.ConvTranspose.precision" title="Permalink to this definition">#</a></dt> <dd><p>numerical precision of the computation see <code class="docutils literal notranslate"><span class="pre">jax.lax.Precision</span></code> for details.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvTranspose.kernel_init"> <span class="sig-name descname"><span class="pre">kernel_init</span></span><a class="headerlink" href="#flax.linen.ConvTranspose.kernel_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer for the convolutional kernel.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvTranspose.bias_init"> <span class="sig-name descname"><span class="pre">bias_init</span></span><a class="headerlink" href="#flax.linen.ConvTranspose.bias_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer for the bias.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvTranspose.transpose_kernel"> <span class="sig-name descname"><span class="pre">transpose_kernel</span></span><a class="headerlink" href="#flax.linen.ConvTranspose.transpose_kernel" title="Permalink to this definition">#</a></dt> <dd><p>if <code class="docutils literal notranslate"><span class="pre">True</span></code> flips spatial axes and swaps the input/output channel axes of the kernel.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.ConvTranspose.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">inputs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/linear.html#ConvTranspose.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.ConvTranspose.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Applies a transposed convolution to the inputs.</p> <p>Behaviour mirrors of <code class="docutils literal notranslate"><span class="pre">jax.lax.conv_transpose</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><p><strong>inputs</strong> – input data with dimensions <code class="docutils literal notranslate"><span class="pre">(*batch_dims,</span> <span class="pre">spatial_dims...,</span> <span class="pre">features).</span></code> This is the channels-last convention, i.e. NHWC for a 2d convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by <code class="docutils literal notranslate"><span class="pre">lax.conv_general_dilated</span></code>, which puts the spatial dimensions last. Note: If the input has more than 1 batch dimension, all batch dimensions are flattened into a single dimension for the convolution and restored before returning. In some cases directly vmap’ing the layer may yield better performance than this default flattening approach. If the input lacks a batch dimension it will be added for the convolution and removed n return, an allowance made to enable writing single-example code.</p> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>The convolved data.</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> </tbody> </table> </div> </dd></dl> </div> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.ConvLocal"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">ConvLocal</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kernel_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">strides=1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">padding='SAME'</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">input_dilation=1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kernel_dilation=1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">feature_group_count=1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_bias=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">precision=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kernel_init=<function</span> <span class="pre">variance_scaling.<locals>.init></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bias_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">conv_general_dilated=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">conv_general_dilated_cls=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/linear.html#ConvLocal"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.ConvLocal" title="Permalink to this definition">#</a></dt> <dd><p>Local convolution Module wrapping <code class="docutils literal notranslate"><span class="pre">lax.conv_general_dilated_local</span></code>.</p> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="c1"># valid padding</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ConvLocal</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,),</span> <span class="n">padding</span><span class="o">=</span><span class="s1">'VALID'</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">out</span><span class="p">,</span> <span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init_with_output</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))</span> <span class="gp">>>> </span><span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">variables</span><span class="p">)</span> <span class="go">{'params': {'bias': (6, 4), 'kernel': (6, 9, 4)}}</span> <span class="gp">>>> </span><span class="n">out</span><span class="o">.</span><span class="n">shape</span> <span class="go">(1, 6, 4)</span> <span class="gp">>>> </span><span class="c1"># circular padding with stride 2</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ConvLocal</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">strides</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s1">'CIRCULAR'</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">out</span><span class="p">,</span> <span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init_with_output</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))</span> <span class="gp">>>> </span><span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">variables</span><span class="p">)</span> <span class="go">{'params': {'bias': (1, 4, 4), 'kernel': (1, 4, 27, 4)}}</span> <span class="gp">>>> </span><span class="n">out</span><span class="o">.</span><span class="n">shape</span> <span class="go">(1, 4, 4)</span> <span class="gp">>>> </span><span class="c1"># apply lower triangle mask</span> <span class="gp">>>> </span><span class="n">mask</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">tril</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">6</span><span class="p">,</span> <span class="mi">9</span><span class="p">,</span> <span class="mi">4</span><span class="p">)))</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ConvLocal</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,),</span> <span class="n">mask</span><span class="o">=</span><span class="n">mask</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s1">'VALID'</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvLocal.features"> <span class="sig-name descname"><span class="pre">features</span></span><a class="headerlink" href="#flax.linen.ConvLocal.features" title="Permalink to this definition">#</a></dt> <dd><p>number of convolution filters.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvLocal.kernel_size"> <span class="sig-name descname"><span class="pre">kernel_size</span></span><a class="headerlink" href="#flax.linen.ConvLocal.kernel_size" title="Permalink to this definition">#</a></dt> <dd><p>shape of the convolutional kernel. An integer will be interpreted as a tuple of the single integer.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int | collections.abc.Sequence[int]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvLocal.strides"> <span class="sig-name descname"><span class="pre">strides</span></span><a class="headerlink" href="#flax.linen.ConvLocal.strides" title="Permalink to this definition">#</a></dt> <dd><p>an integer or a sequence of <cite>n</cite> integers, representing the inter-window strides (default: 1).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>None | int | collections.abc.Sequence[int]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvLocal.padding"> <span class="sig-name descname"><span class="pre">padding</span></span><a class="headerlink" href="#flax.linen.ConvLocal.padding" title="Permalink to this definition">#</a></dt> <dd><p>either the string <code class="docutils literal notranslate"><span class="pre">'SAME'</span></code>, the string <code class="docutils literal notranslate"><span class="pre">'VALID'</span></code>, the string <code class="docutils literal notranslate"><span class="pre">'CIRCULAR'</span></code> (periodic boundary conditions), or a sequence of <code class="docutils literal notranslate"><span class="pre">n</span></code> <code class="docutils literal notranslate"><span class="pre">(low,</span> <span class="pre">high)</span></code> integer pairs that give the padding to apply before and after each spatial dimension. A single int is interpreted as applying the same padding in all dims and assign a single int in a sequence causes the same padding to be used on both sides. <code class="docutils literal notranslate"><span class="pre">'CAUSAL'</span></code> padding for a 1D convolution will left-pad the convolution axis, resulting in same-sized output.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, int, collections.abc.Sequence[Union[int, tuple[int, int]]]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvLocal.input_dilation"> <span class="sig-name descname"><span class="pre">input_dilation</span></span><a class="headerlink" href="#flax.linen.ConvLocal.input_dilation" title="Permalink to this definition">#</a></dt> <dd><p>an integer or a sequence of <code class="docutils literal notranslate"><span class="pre">n</span></code> integers, giving the dilation factor to apply in each spatial dimension of <code class="docutils literal notranslate"><span class="pre">inputs</span></code> (default: 1). Convolution with input dilation <code class="docutils literal notranslate"><span class="pre">d</span></code> is equivalent to transposed convolution with stride <code class="docutils literal notranslate"><span class="pre">d</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>None | int | collections.abc.Sequence[int]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvLocal.kernel_dilation"> <span class="sig-name descname"><span class="pre">kernel_dilation</span></span><a class="headerlink" href="#flax.linen.ConvLocal.kernel_dilation" title="Permalink to this definition">#</a></dt> <dd><p>an integer or a sequence of <code class="docutils literal notranslate"><span class="pre">n</span></code> integers, giving the dilation factor to apply in each spatial dimension of the convolution kernel (default: 1). Convolution with kernel dilation is also known as ‘atrous convolution’.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>None | int | collections.abc.Sequence[int]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvLocal.feature_group_count"> <span class="sig-name descname"><span class="pre">feature_group_count</span></span><a class="headerlink" href="#flax.linen.ConvLocal.feature_group_count" title="Permalink to this definition">#</a></dt> <dd><p>integer, default 1. If specified divides the input features into groups.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvLocal.use_bias"> <span class="sig-name descname"><span class="pre">use_bias</span></span><a class="headerlink" href="#flax.linen.ConvLocal.use_bias" title="Permalink to this definition">#</a></dt> <dd><p>whether to add a bias to the output (default: True).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvLocal.mask"> <span class="sig-name descname"><span class="pre">mask</span></span><a class="headerlink" href="#flax.linen.ConvLocal.mask" title="Permalink to this definition">#</a></dt> <dd><p>Optional mask for the weights during masked convolution. The mask must be the same shape as the convolution weight matrix.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[jax.Array, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvLocal.dtype"> <span class="sig-name descname"><span class="pre">dtype</span></span><a class="headerlink" href="#flax.linen.ConvLocal.dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype of the computation (default: infer from input and params).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvLocal.param_dtype"> <span class="sig-name descname"><span class="pre">param_dtype</span></span><a class="headerlink" href="#flax.linen.ConvLocal.param_dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype passed to parameter initializers (default: float32).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvLocal.precision"> <span class="sig-name descname"><span class="pre">precision</span></span><a class="headerlink" href="#flax.linen.ConvLocal.precision" title="Permalink to this definition">#</a></dt> <dd><p>numerical precision of the computation see <code class="docutils literal notranslate"><span class="pre">jax.lax.Precision</span></code> for details.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvLocal.kernel_init"> <span class="sig-name descname"><span class="pre">kernel_init</span></span><a class="headerlink" href="#flax.linen.ConvLocal.kernel_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer for the convolutional kernel.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvLocal.bias_init"> <span class="sig-name descname"><span class="pre">bias_init</span></span><a class="headerlink" href="#flax.linen.ConvLocal.bias_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer for the bias.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.ConvLocal.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">inputs</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#flax.linen.ConvLocal.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Applies a (potentially unshared) convolution to the inputs.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><p><strong>inputs</strong> – input data with dimensions <code class="docutils literal notranslate"><span class="pre">(*batch_dims,</span> <span class="pre">spatial_dims...,</span> <span class="pre">features)</span></code>. This is the channels-last convention, i.e. NHWC for a 2d convolution and NDHWC for a 3D convolution. Note: this is different from the input convention used by <code class="docutils literal notranslate"><span class="pre">lax.conv_general_dilated</span></code>, which puts the spatial dimensions last. Note: If the input has more than 1 batch dimension, all batch dimensions are flattened into a single dimension for the convolution and restored before returning. In some cases directly vmap’ing the layer may yield better performance than this default flattening approach. If the input lacks a batch dimension it will be added for the convolution and removed n return, an allowance made to enable writing single-example code.</p> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>The convolved data.</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> </tbody> </table> </div> </dd></dl> </div> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.Einsum"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">Einsum</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">shape</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">einsum_str=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_bias=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">precision=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kernel_init=<function</span> <span class="pre">variance_scaling.<locals>.init></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bias_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/linear.html#Einsum"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Einsum" title="Permalink to this definition">#</a></dt> <dd><p>An einsum transformation with learnable kernel and bias.</p> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Einsum</span><span class="p">((</span><span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">7</span><span class="p">),</span> <span class="s1">'abc,cde->abde'</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">)))</span> <span class="gp">>>> </span><span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">variables</span><span class="p">)</span> <span class="go">{'params': {'bias': (6, 7), 'kernel': (5, 6, 7)}}</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Einsum.shape"> <span class="sig-name descname"><span class="pre">shape</span></span><a class="headerlink" href="#flax.linen.Einsum.shape" title="Permalink to this definition">#</a></dt> <dd><p>the shape of the kernel.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Sequence[int]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Einsum.einsum_str"> <span class="sig-name descname"><span class="pre">einsum_str</span></span><a class="headerlink" href="#flax.linen.Einsum.einsum_str" title="Permalink to this definition">#</a></dt> <dd><p>a string to denote the einsum equation. The equation must have exactly two operands, the lhs being the input passed in, and the rhs being the learnable kernel. Exactly one of <code class="docutils literal notranslate"><span class="pre">einsum_str</span></code> in the constructor argument and call argument must be not None, while the other must be None.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>str | None</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Einsum.use_bias"> <span class="sig-name descname"><span class="pre">use_bias</span></span><a class="headerlink" href="#flax.linen.Einsum.use_bias" title="Permalink to this definition">#</a></dt> <dd><p>whether to add a bias to the output (default: True).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Einsum.dtype"> <span class="sig-name descname"><span class="pre">dtype</span></span><a class="headerlink" href="#flax.linen.Einsum.dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype of the computation (default: infer from input and params).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Einsum.param_dtype"> <span class="sig-name descname"><span class="pre">param_dtype</span></span><a class="headerlink" href="#flax.linen.Einsum.param_dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype passed to parameter initializers (default: float32).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Einsum.precision"> <span class="sig-name descname"><span class="pre">precision</span></span><a class="headerlink" href="#flax.linen.Einsum.precision" title="Permalink to this definition">#</a></dt> <dd><p>numerical precision of the computation see <code class="docutils literal notranslate"><span class="pre">jax.lax.Precision</span></code> for details.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Einsum.kernel_init"> <span class="sig-name descname"><span class="pre">kernel_init</span></span><a class="headerlink" href="#flax.linen.Einsum.kernel_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer function for the weight matrix.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Einsum.bias_init"> <span class="sig-name descname"><span class="pre">bias_init</span></span><a class="headerlink" href="#flax.linen.Einsum.bias_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer function for the bias.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Einsum.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">inputs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">einsum_str</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/linear.html#Einsum.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Einsum.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Applies a linear transformation to the inputs along the last dimension.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>inputs</strong> – The nd-array to be transformed.</p></li> <li><p><strong>einsum_str</strong> – a string to denote the einsum equation. The equation must have exactly two operands, the lhs being the input passed in, and the rhs being the learnable kernel. The <code class="docutils literal notranslate"><span class="pre">einsum_str</span></code> passed into the call method will take precedence over the <code class="docutils literal notranslate"><span class="pre">einsum_str</span></code> passed into the constructor.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>The transformed input.</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> </tbody> </table> </div> </dd></dl> </div> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.Embed"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">Embed</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">num_embeddings</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">embedding_init=<function</span> <span class="pre">variance_scaling.<locals>.init></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/linear.html#Embed"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Embed" title="Permalink to this definition">#</a></dt> <dd><p>Embedding Module.</p> <p>A parameterized function from integers [0, <code class="docutils literal notranslate"><span class="pre">num_embeddings</span></code>) to <code class="docutils literal notranslate"><span class="pre">features</span></code>-dimensional vectors. This <code class="docutils literal notranslate"><span class="pre">Module</span></code> will create an <code class="docutils literal notranslate"><span class="pre">embedding</span></code> matrix with shape <code class="docutils literal notranslate"><span class="pre">(num_embeddings,</span> <span class="pre">features)</span></code>. When calling this layer, the input values will be used to 0-index into the <code class="docutils literal notranslate"><span class="pre">embedding</span></code> matrix. Indexing on a value greater than or equal to <code class="docutils literal notranslate"><span class="pre">num_embeddings</span></code> will result in <code class="docutils literal notranslate"><span class="pre">nan</span></code> values. When <code class="docutils literal notranslate"><span class="pre">num_embeddings</span></code> equals to 1, it will broadcast the <code class="docutils literal notranslate"><span class="pre">embedding</span></code> matrix to input shape with <code class="docutils literal notranslate"><span class="pre">features</span></code> dimension appended.</p> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Embed</span><span class="p">(</span><span class="n">num_embeddings</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">features</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">indices_input</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">2</span><span class="p">,</span> <span class="o">-</span><span class="mi">3</span><span class="p">]])</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">indices_input</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="go">{'params': {'embedding': Array([[-0.28884724, 0.19018005, -0.414205 ],</span> <span class="go"> [-0.11768015, -0.54618824, -0.3789283 ],</span> <span class="go"> [ 0.30428642, 0.49511626, 0.01706631],</span> <span class="go"> [-0.0982546 , -0.43055868, 0.20654906],</span> <span class="go"> [-0.688412 , -0.46882293, 0.26723292]], dtype=float32)}}</span> <span class="gp">>>> </span><span class="c1"># get the first three and last three embeddings</span> <span class="gp">>>> </span><span class="n">layer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">indices_input</span><span class="p">)</span> <span class="go">Array([[[-0.28884724, 0.19018005, -0.414205 ],</span> <span class="go"> [-0.11768015, -0.54618824, -0.3789283 ],</span> <span class="go"> [ 0.30428642, 0.49511626, 0.01706631]],</span> <span class="go"> [[-0.688412 , -0.46882293, 0.26723292],</span> <span class="go"> [-0.0982546 , -0.43055868, 0.20654906],</span> <span class="go"> [ 0.30428642, 0.49511626, 0.01706631]]], dtype=float32)</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Embed.num_embeddings"> <span class="sig-name descname"><span class="pre">num_embeddings</span></span><a class="headerlink" href="#flax.linen.Embed.num_embeddings" title="Permalink to this definition">#</a></dt> <dd><p>number of embeddings / vocab size.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Embed.features"> <span class="sig-name descname"><span class="pre">features</span></span><a class="headerlink" href="#flax.linen.Embed.features" title="Permalink to this definition">#</a></dt> <dd><p>number of feature dimensions for each embedding.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Embed.dtype"> <span class="sig-name descname"><span class="pre">dtype</span></span><a class="headerlink" href="#flax.linen.Embed.dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype of the embedding vectors (default: same as embedding).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Embed.param_dtype"> <span class="sig-name descname"><span class="pre">param_dtype</span></span><a class="headerlink" href="#flax.linen.Embed.param_dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype passed to parameter initializers (default: float32).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Embed.embedding_init"> <span class="sig-name descname"><span class="pre">embedding_init</span></span><a class="headerlink" href="#flax.linen.Embed.embedding_init" title="Permalink to this definition">#</a></dt> <dd><p>embedding initializer.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Embed.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">inputs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/linear.html#Embed.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Embed.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Embeds the inputs along the last dimension.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><p><strong>inputs</strong> – input data, all dimensions are considered batch dimensions. Values in the input array must be integers.</p> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>Output which is embedded input data. The output shape follows the input, with an additional <code class="docutils literal notranslate"><span class="pre">features</span></code> dimension appended.</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Embed.attend"> <span class="sig-name descname"><span class="pre">attend</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">query</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/linear.html#Embed.attend"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Embed.attend" title="Permalink to this definition">#</a></dt> <dd><p>Attend over the embedding using a query array.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><p><strong>query</strong> – array with last dimension equal the feature depth <code class="docutils literal notranslate"><span class="pre">features</span></code> of the embedding.</p> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>An array with final dim <code class="docutils literal notranslate"><span class="pre">num_embeddings</span></code> corresponding to the batched inner-product of the array of query vectors against each embedding. Commonly used for weight-sharing between embeddings and logit transform in NLP models.</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> <tr class="row-odd"><td><p><a class="reference internal" href="#flax.linen.Embed.attend" title="flax.linen.Embed.attend"><code class="xref py py-obj docutils literal notranslate"><span class="pre">attend</span></code></a>(query)</p></td> <td><p>Attend over the embedding using a query array.</p></td> </tr> </tbody> </table> </div> </dd></dl> </div> </div> <div class="section" id="pooling"> <h2>Pooling<a class="headerlink" href="#pooling" title="Permalink to this heading">#</a></h2> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.max_pool"> <span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">max_pool</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">inputs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">window_shape</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">strides</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">padding</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'VALID'</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/pooling.html#max_pool"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.max_pool" title="Permalink to this definition">#</a></dt> <dd><p>Pools the input by taking the maximum of a window slice.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>inputs</strong> – input data with dimensions (batch, window dims…, features).</p></li> <li><p><strong>window_shape</strong> – a shape tuple defining the window to reduce over.</p></li> <li><p><strong>strides</strong> – a sequence of <code class="docutils literal notranslate"><span class="pre">n</span></code> integers, representing the inter-window strides (default: <code class="docutils literal notranslate"><span class="pre">(1,</span> <span class="pre">...,</span> <span class="pre">1)</span></code>).</p></li> <li><p><strong>padding</strong> – either the string <code class="docutils literal notranslate"><span class="pre">'SAME'</span></code>, the string <code class="docutils literal notranslate"><span class="pre">'VALID'</span></code>, or a sequence of <code class="docutils literal notranslate"><span class="pre">n</span></code> <code class="docutils literal notranslate"><span class="pre">(low,</span> <span class="pre">high)</span></code> integer pairs that give the padding to apply before and after each spatial dimension (default: <code class="docutils literal notranslate"><span class="pre">'VALID'</span></code>).</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>The maximum for each window slice.</p> </dd> </dl> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.avg_pool"> <span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">avg_pool</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">inputs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">window_shape</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">strides</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">padding</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'VALID'</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">count_include_pad</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/pooling.html#avg_pool"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.avg_pool" title="Permalink to this definition">#</a></dt> <dd><p>Pools the input by taking the average over a window.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>inputs</strong> – input data with dimensions (batch, window dims…, features).</p></li> <li><p><strong>window_shape</strong> – a shape tuple defining the window to reduce over.</p></li> <li><p><strong>strides</strong> – a sequence of <code class="docutils literal notranslate"><span class="pre">n</span></code> integers, representing the inter-window strides (default: <code class="docutils literal notranslate"><span class="pre">(1,</span> <span class="pre">...,</span> <span class="pre">1)</span></code>).</p></li> <li><p><strong>padding</strong> – either the string <code class="docutils literal notranslate"><span class="pre">'SAME'</span></code>, the string <code class="docutils literal notranslate"><span class="pre">'VALID'</span></code>, or a sequence of <code class="docutils literal notranslate"><span class="pre">n</span></code> <code class="docutils literal notranslate"><span class="pre">(low,</span> <span class="pre">high)</span></code> integer pairs that give the padding to apply before and after each spatial dimension (default: <code class="docutils literal notranslate"><span class="pre">'VALID'</span></code>).</p></li> <li><p><strong>count_include_pad</strong> – a boolean whether to include padded tokens in the average calculation (default: <code class="docutils literal notranslate"><span class="pre">True</span></code>).</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>The average for each window slice.</p> </dd> </dl> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.pool"> <span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">pool</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">inputs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">init</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reduce_fn</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">window_shape</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">strides</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">padding</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/pooling.html#pool"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.pool" title="Permalink to this definition">#</a></dt> <dd><p>Helper function to define pooling functions.</p> <p>Pooling functions are implemented using the ReduceWindow XLA op.</p> <div class="admonition note"> <p class="admonition-title">Note</p> <p>Be aware that pooling is not generally differentiable. That means providing a reduce_fn that is differentiable does not imply that pool is differentiable.</p> </div> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>inputs</strong> – input data with dimensions (batch, window dims…, features).</p></li> <li><p><strong>init</strong> – the initial value for the reduction</p></li> <li><p><strong>reduce_fn</strong> – a reduce function of the form <code class="docutils literal notranslate"><span class="pre">(T,</span> <span class="pre">T)</span> <span class="pre">-></span> <span class="pre">T</span></code>.</p></li> <li><p><strong>window_shape</strong> – a shape tuple defining the window to reduce over.</p></li> <li><p><strong>strides</strong> – a sequence of <code class="docutils literal notranslate"><span class="pre">n</span></code> integers, representing the inter-window strides (default: <code class="docutils literal notranslate"><span class="pre">(1,</span> <span class="pre">...,</span> <span class="pre">1)</span></code>).</p></li> <li><p><strong>padding</strong> – either the string <code class="docutils literal notranslate"><span class="pre">'SAME'</span></code>, the string <code class="docutils literal notranslate"><span class="pre">'VALID'</span></code>, or a sequence of <code class="docutils literal notranslate"><span class="pre">n</span></code> <code class="docutils literal notranslate"><span class="pre">(low,</span> <span class="pre">high)</span></code> integer pairs that give the padding to apply before and after each spatial dimension.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>The output of the reduction for each window slice.</p> </dd> </dl> </dd></dl> </div> <div class="section" id="normalization"> <h2>Normalization<a class="headerlink" href="#normalization" title="Permalink to this heading">#</a></h2> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.BatchNorm"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">BatchNorm</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">use_running_average=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">axis=-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">momentum=0.99</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">epsilon=1e-05</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_bias=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_scale=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bias_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scale_init=<function</span> <span class="pre">ones></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">axis_name=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">axis_index_groups=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_fast_variance=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">force_float32_reductions=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/normalization.html#BatchNorm"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.BatchNorm" title="Permalink to this definition">#</a></dt> <dd><p>BatchNorm Module.</p> <p>Usage Note: If we define a model with BatchNorm, for example:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="n">BN</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">(</span><span class="n">momentum</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> </pre></div> </div> <p>The initialized variables dict will contain, in addition to a ‘params’ collection, a separate ‘batch_stats’ collection that will contain all the running statistics for all the BatchNorm layers in a model:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">BN</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">x</span><span class="p">,</span> <span class="n">use_running_average</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">variables</span><span class="p">)</span> <span class="go">{'batch_stats': {'mean': (6,), 'var': (6,)}, 'params': {'bias': (6,), 'scale': (6,)}}</span> </pre></div> </div> <p>We then update the batch_stats during training by specifying that the <code class="docutils literal notranslate"><span class="pre">batch_stats</span></code> collection is mutable in the <code class="docutils literal notranslate"><span class="pre">apply</span></code> method for our module.:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">y</span><span class="p">,</span> <span class="n">new_batch_stats</span> <span class="o">=</span> <span class="n">BN</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mutable</span><span class="o">=</span><span class="p">[</span><span class="s1">'batch_stats'</span><span class="p">],</span> <span class="n">use_running_average</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> </pre></div> </div> <p>During eval we would define BN with <code class="docutils literal notranslate"><span class="pre">use_running_average=True</span></code> and use the batch_stats collection from training to set the statistics. In this case we are not mutating the batch statistics collection, and needn’t mark it mutable:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">y</span> <span class="o">=</span> <span class="n">BN</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mutable</span><span class="o">=</span><span class="p">[</span><span class="s1">'batch_stats'</span><span class="p">],</span> <span class="n">use_running_average</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.BatchNorm.use_running_average"> <span class="sig-name descname"><span class="pre">use_running_average</span></span><a class="headerlink" href="#flax.linen.BatchNorm.use_running_average" title="Permalink to this definition">#</a></dt> <dd><p>if True, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool | None</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.BatchNorm.axis"> <span class="sig-name descname"><span class="pre">axis</span></span><a class="headerlink" href="#flax.linen.BatchNorm.axis" title="Permalink to this definition">#</a></dt> <dd><p>the feature or non-batch axis of the input.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.BatchNorm.momentum"> <span class="sig-name descname"><span class="pre">momentum</span></span><a class="headerlink" href="#flax.linen.BatchNorm.momentum" title="Permalink to this definition">#</a></dt> <dd><p>decay rate for the exponential moving average of the batch statistics.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>float</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.BatchNorm.epsilon"> <span class="sig-name descname"><span class="pre">epsilon</span></span><a class="headerlink" href="#flax.linen.BatchNorm.epsilon" title="Permalink to this definition">#</a></dt> <dd><p>a small float added to variance to avoid dividing by zero.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>float</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.BatchNorm.dtype"> <span class="sig-name descname"><span class="pre">dtype</span></span><a class="headerlink" href="#flax.linen.BatchNorm.dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype of the result (default: infer from input and params).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.BatchNorm.param_dtype"> <span class="sig-name descname"><span class="pre">param_dtype</span></span><a class="headerlink" href="#flax.linen.BatchNorm.param_dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype passed to parameter initializers (default: float32).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.BatchNorm.use_bias"> <span class="sig-name descname"><span class="pre">use_bias</span></span><a class="headerlink" href="#flax.linen.BatchNorm.use_bias" title="Permalink to this definition">#</a></dt> <dd><p>if True, bias (beta) is added.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.BatchNorm.use_scale"> <span class="sig-name descname"><span class="pre">use_scale</span></span><a class="headerlink" href="#flax.linen.BatchNorm.use_scale" title="Permalink to this definition">#</a></dt> <dd><p>if True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.BatchNorm.bias_init"> <span class="sig-name descname"><span class="pre">bias_init</span></span><a class="headerlink" href="#flax.linen.BatchNorm.bias_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer for bias, by default, zero.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.BatchNorm.scale_init"> <span class="sig-name descname"><span class="pre">scale_init</span></span><a class="headerlink" href="#flax.linen.BatchNorm.scale_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer for scale, by default, one.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.BatchNorm.axis_name"> <span class="sig-name descname"><span class="pre">axis_name</span></span><a class="headerlink" href="#flax.linen.BatchNorm.axis_name" title="Permalink to this definition">#</a></dt> <dd><p>the axis name used to combine batch statistics from multiple devices. See <code class="docutils literal notranslate"><span class="pre">jax.pmap</span></code> for a description of axis names (default: None). Note, this is only used for pmap and shard map. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>str | None</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.BatchNorm.axis_index_groups"> <span class="sig-name descname"><span class="pre">axis_index_groups</span></span><a class="headerlink" href="#flax.linen.BatchNorm.axis_index_groups" title="Permalink to this definition">#</a></dt> <dd><p>groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, <code class="docutils literal notranslate"><span class="pre">[[0,</span> <span class="pre">1],</span> <span class="pre">[2,</span> <span class="pre">3]]</span></code> would independently batch-normalize over the examples on the first two and last two devices. See <code class="docutils literal notranslate"><span class="pre">jax.lax.psum</span></code> for more details.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Any</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.BatchNorm.use_fast_variance"> <span class="sig-name descname"><span class="pre">use_fast_variance</span></span><a class="headerlink" href="#flax.linen.BatchNorm.use_fast_variance" title="Permalink to this definition">#</a></dt> <dd><p>If true, use a faster, but less numerically stable, calculation for the variance.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.BatchNorm.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_running_average</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/normalization.html#BatchNorm.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.BatchNorm.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Normalizes the input using batch statistics.</p> <div class="admonition note"> <p class="admonition-title">Note</p> <p>During initialization (when <code class="docutils literal notranslate"><span class="pre">self.is_initializing()</span></code> is <code class="docutils literal notranslate"><span class="pre">True</span></code>) the running average of the batch statistics will not be updated. Therefore, the inputs fed during initialization don’t need to match that of the actual input distribution and the reduction axis (set with <code class="docutils literal notranslate"><span class="pre">axis_name</span></code>) does not have to exist.</p> </div> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>x</strong> – the input to be normalized.</p></li> <li><p><strong>use_running_average</strong> – if true, the statistics stored in batch_stats will be used instead of computing the batch statistics on the input.</p></li> <li><p><strong>mask</strong> – Binary array of shape broadcastable to <code class="docutils literal notranslate"><span class="pre">inputs</span></code> tensor, indicating the positions for which the mean and variance should be computed.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>Normalized inputs (the same shape as inputs).</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> </tbody> </table> </div> </dd></dl> </div> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.LayerNorm"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">LayerNorm</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">epsilon=1e-06</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_bias=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_scale=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bias_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scale_init=<function</span> <span class="pre">ones></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reduction_axes=-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">feature_axes=-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">axis_name=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">axis_index_groups=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_fast_variance=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">force_float32_reductions=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/normalization.html#LayerNorm"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.LayerNorm" title="Permalink to this definition">#</a></dt> <dd><p>Layer normalization (<a class="reference external" href="https://arxiv.org/abs/1607.06450">https://arxiv.org/abs/1607.06450</a>).</p> <p>LayerNorm normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. i.e. applies a transformation that maintains the mean activation within each example close to 0 and the activation standard deviation close to 1.</p> <div class="admonition note"> <p class="admonition-title">Note</p> <p>This normalization operation is identical to InstanceNorm and GroupNorm; the difference is simply which axes are reduced and the shape of the feature axes (i.e. the shape of the learnable scale and bias parameters).</p> </div> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">()</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="go">{'params': {'scale': Array([1., 1., 1., 1., 1., 1.], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0.], dtype=float32)}}</span> <span class="gp">>>> </span><span class="n">y</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">y</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">reduction_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">y2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GroupNorm</span><span class="p">(</span><span class="n">num_groups</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">y2</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">y</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">reduction_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">feature_axes</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">y2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">InstanceNorm</span><span class="p">(</span><span class="n">feature_axes</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">y2</span><span class="p">)</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.LayerNorm.epsilon"> <span class="sig-name descname"><span class="pre">epsilon</span></span><a class="headerlink" href="#flax.linen.LayerNorm.epsilon" title="Permalink to this definition">#</a></dt> <dd><p>A small float added to variance to avoid dividing by zero.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>float</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.LayerNorm.dtype"> <span class="sig-name descname"><span class="pre">dtype</span></span><a class="headerlink" href="#flax.linen.LayerNorm.dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype of the result (default: infer from input and params).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.LayerNorm.param_dtype"> <span class="sig-name descname"><span class="pre">param_dtype</span></span><a class="headerlink" href="#flax.linen.LayerNorm.param_dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype passed to parameter initializers (default: float32).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.LayerNorm.use_bias"> <span class="sig-name descname"><span class="pre">use_bias</span></span><a class="headerlink" href="#flax.linen.LayerNorm.use_bias" title="Permalink to this definition">#</a></dt> <dd><p>If True, bias (beta) is added.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.LayerNorm.use_scale"> <span class="sig-name descname"><span class="pre">use_scale</span></span><a class="headerlink" href="#flax.linen.LayerNorm.use_scale" title="Permalink to this definition">#</a></dt> <dd><p>If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.LayerNorm.bias_init"> <span class="sig-name descname"><span class="pre">bias_init</span></span><a class="headerlink" href="#flax.linen.LayerNorm.bias_init" title="Permalink to this definition">#</a></dt> <dd><p>Initializer for bias, by default, zero.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.LayerNorm.scale_init"> <span class="sig-name descname"><span class="pre">scale_init</span></span><a class="headerlink" href="#flax.linen.LayerNorm.scale_init" title="Permalink to this definition">#</a></dt> <dd><p>Initializer for scale, by default, one.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.LayerNorm.reduction_axes"> <span class="sig-name descname"><span class="pre">reduction_axes</span></span><a class="headerlink" href="#flax.linen.LayerNorm.reduction_axes" title="Permalink to this definition">#</a></dt> <dd><p>Axes for computing normalization statistics.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[int, collections.abc.Sequence[int]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.LayerNorm.feature_axes"> <span class="sig-name descname"><span class="pre">feature_axes</span></span><a class="headerlink" href="#flax.linen.LayerNorm.feature_axes" title="Permalink to this definition">#</a></dt> <dd><p>Feature axes for learned bias and scaling.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[int, collections.abc.Sequence[int]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.LayerNorm.axis_name"> <span class="sig-name descname"><span class="pre">axis_name</span></span><a class="headerlink" href="#flax.linen.LayerNorm.axis_name" title="Permalink to this definition">#</a></dt> <dd><p>the axis name used to combine batch statistics from multiple devices. See <code class="docutils literal notranslate"><span class="pre">jax.pmap</span></code> for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap or shard map. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>str | None</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.LayerNorm.axis_index_groups"> <span class="sig-name descname"><span class="pre">axis_index_groups</span></span><a class="headerlink" href="#flax.linen.LayerNorm.axis_index_groups" title="Permalink to this definition">#</a></dt> <dd><p>groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, <code class="docutils literal notranslate"><span class="pre">[[0,</span> <span class="pre">1],</span> <span class="pre">[2,</span> <span class="pre">3]]</span></code> would independently batch-normalize over the examples on the first two and last two devices. See <code class="docutils literal notranslate"><span class="pre">jax.lax.psum</span></code> for more details.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Any</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.LayerNorm.use_fast_variance"> <span class="sig-name descname"><span class="pre">use_fast_variance</span></span><a class="headerlink" href="#flax.linen.LayerNorm.use_fast_variance" title="Permalink to this definition">#</a></dt> <dd><p>If true, use a faster, but less numerically stable, calculation for the variance.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.LayerNorm.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/normalization.html#LayerNorm.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.LayerNorm.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Applies layer normalization on the input.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>x</strong> – the inputs</p></li> <li><p><strong>mask</strong> – Binary array of shape broadcastable to <code class="docutils literal notranslate"><span class="pre">inputs</span></code> tensor, indicating the positions for which the mean and variance should be computed.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>Normalized inputs (the same shape as inputs).</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> </tbody> </table> </div> </dd></dl> </div> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.GroupNorm"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">GroupNorm</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">num_groups=32</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">group_size=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">epsilon=1e-06</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_bias=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_scale=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bias_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scale_init=<function</span> <span class="pre">ones></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reduction_axes=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">axis_name=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">axis_index_groups=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_fast_variance=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">force_float32_reductions=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/normalization.html#GroupNorm"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.GroupNorm" title="Permalink to this definition">#</a></dt> <dd><p>Group normalization (arxiv.org/abs/1803.08494).</p> <p>This op is similar to batch normalization, but statistics are shared across equally-sized groups of channels and not shared across batch dimension. Thus, group normalization does not depend on the batch composition and does not require maintaining internal state for storing statistics. The user should either specify the total number of channel groups or the number of channels per group.</p> <div class="admonition note"> <p class="admonition-title">Note</p> <p>LayerNorm is a special case of GroupNorm where <code class="docutils literal notranslate"><span class="pre">num_groups=1</span></code>, and InstanceNorm is a special case of GroupNorm where <code class="docutils literal notranslate"><span class="pre">group_size=1</span></code>.</p> </div> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GroupNorm</span><span class="p">(</span><span class="n">num_groups</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="go">{'params': {'scale': Array([1., 1., 1., 1., 1., 1.], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0.], dtype=float32)}}</span> <span class="gp">>>> </span><span class="n">y</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">y</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GroupNorm</span><span class="p">(</span><span class="n">num_groups</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">y2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">reduction_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">y2</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">y</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GroupNorm</span><span class="p">(</span><span class="n">num_groups</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">group_size</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">y2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">InstanceNorm</span><span class="p">(</span><span class="n">feature_axes</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">y2</span><span class="p">)</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.GroupNorm.num_groups"> <span class="sig-name descname"><span class="pre">num_groups</span></span><a class="headerlink" href="#flax.linen.GroupNorm.num_groups" title="Permalink to this definition">#</a></dt> <dd><p>the total number of channel groups. The default value of 32 is proposed by the original group normalization paper.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int | None</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.GroupNorm.group_size"> <span class="sig-name descname"><span class="pre">group_size</span></span><a class="headerlink" href="#flax.linen.GroupNorm.group_size" title="Permalink to this definition">#</a></dt> <dd><p>the number of channels in a group.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int | None</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.GroupNorm.epsilon"> <span class="sig-name descname"><span class="pre">epsilon</span></span><a class="headerlink" href="#flax.linen.GroupNorm.epsilon" title="Permalink to this definition">#</a></dt> <dd><p>A small float added to variance to avoid dividing by zero.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>float</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.GroupNorm.dtype"> <span class="sig-name descname"><span class="pre">dtype</span></span><a class="headerlink" href="#flax.linen.GroupNorm.dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype of the result (default: infer from input and params).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.GroupNorm.param_dtype"> <span class="sig-name descname"><span class="pre">param_dtype</span></span><a class="headerlink" href="#flax.linen.GroupNorm.param_dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype passed to parameter initializers (default: float32).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.GroupNorm.use_bias"> <span class="sig-name descname"><span class="pre">use_bias</span></span><a class="headerlink" href="#flax.linen.GroupNorm.use_bias" title="Permalink to this definition">#</a></dt> <dd><p>If True, bias (beta) is added.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.GroupNorm.use_scale"> <span class="sig-name descname"><span class="pre">use_scale</span></span><a class="headerlink" href="#flax.linen.GroupNorm.use_scale" title="Permalink to this definition">#</a></dt> <dd><p>If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.GroupNorm.bias_init"> <span class="sig-name descname"><span class="pre">bias_init</span></span><a class="headerlink" href="#flax.linen.GroupNorm.bias_init" title="Permalink to this definition">#</a></dt> <dd><p>Initializer for bias, by default, zero.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.GroupNorm.scale_init"> <span class="sig-name descname"><span class="pre">scale_init</span></span><a class="headerlink" href="#flax.linen.GroupNorm.scale_init" title="Permalink to this definition">#</a></dt> <dd><p>Initializer for scale, by default, one.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.GroupNorm.reduction_axes"> <span class="sig-name descname"><span class="pre">reduction_axes</span></span><a class="headerlink" href="#flax.linen.GroupNorm.reduction_axes" title="Permalink to this definition">#</a></dt> <dd><p>List of axes used for computing normalization statistics. This list must include the final dimension, which is assumed to be the feature axis. Furthermore, if the input used at call time has additional leading axes compared to the data used for initialisation, for example due to batching, then the reduction axes need to be defined explicitly.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[int, collections.abc.Sequence[int]]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.GroupNorm.axis_name"> <span class="sig-name descname"><span class="pre">axis_name</span></span><a class="headerlink" href="#flax.linen.GroupNorm.axis_name" title="Permalink to this definition">#</a></dt> <dd><p>the axis name used to combine batch statistics from multiple devices. See <code class="docutils literal notranslate"><span class="pre">jax.pmap</span></code> for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap or shard map. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>str | None</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.GroupNorm.axis_index_groups"> <span class="sig-name descname"><span class="pre">axis_index_groups</span></span><a class="headerlink" href="#flax.linen.GroupNorm.axis_index_groups" title="Permalink to this definition">#</a></dt> <dd><p>groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, <code class="docutils literal notranslate"><span class="pre">[[0,</span> <span class="pre">1],</span> <span class="pre">[2,</span> <span class="pre">3]]</span></code> would independently batch-normalize over the examples on the first two and last two devices. See <code class="docutils literal notranslate"><span class="pre">jax.lax.psum</span></code> for more details.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Any</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.GroupNorm.use_fast_variance"> <span class="sig-name descname"><span class="pre">use_fast_variance</span></span><a class="headerlink" href="#flax.linen.GroupNorm.use_fast_variance" title="Permalink to this definition">#</a></dt> <dd><p>If true, use a faster, but less numerically stable, calculation for the variance.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.GroupNorm.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/normalization.html#GroupNorm.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.GroupNorm.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Applies group normalization to the input (arxiv.org/abs/1803.08494).</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>x</strong> – the input of shape <code class="docutils literal notranslate"><span class="pre">...C</span></code> where <code class="docutils literal notranslate"><span class="pre">C</span></code> is a channels dimension and <code class="docutils literal notranslate"><span class="pre">...</span></code> represents an arbitrary number of extra dimensions that can be used to accumulate statistics over. If no reduction axes have been specified then all additional dimensions <code class="docutils literal notranslate"><span class="pre">...</span></code> will be used to accumulate statistics apart from the leading dimension which is assumed to represent the batch.</p></li> <li><p><strong>mask</strong> – Binary array of shape broadcastable to <code class="docutils literal notranslate"><span class="pre">inputs</span></code> tensor, indicating the positions for which the mean and variance should be computed.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>Normalized inputs (the same shape as inputs).</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> </tbody> </table> </div> </dd></dl> </div> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.RMSNorm"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">RMSNorm</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">epsilon=1e-06</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_scale=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scale_init=<function</span> <span class="pre">ones></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reduction_axes=-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">feature_axes=-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">axis_name=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">axis_index_groups=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_fast_variance=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">force_float32_reductions=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/normalization.html#RMSNorm"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.RMSNorm" title="Permalink to this definition">#</a></dt> <dd><p>RMS Layer normalization (<a class="reference external" href="https://arxiv.org/abs/1910.07467">https://arxiv.org/abs/1910.07467</a>).</p> <p>RMSNorm normalizes the activations of the layer for each given example in a batch independently, rather than across a batch like Batch Normalization. Unlike LayerNorm which re-centers the mean to be 0 and normalizes by the standard deviation of the activations, RMSNorm does not re-center at all and instead normalizes by the root mean square of the activations.</p> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">RMSNorm</span><span class="p">()</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="go">{'params': {'scale': Array([1., 1., 1., 1., 1., 1.], dtype=float32)}}</span> <span class="gp">>>> </span><span class="n">y</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.RMSNorm.epsilon"> <span class="sig-name descname"><span class="pre">epsilon</span></span><a class="headerlink" href="#flax.linen.RMSNorm.epsilon" title="Permalink to this definition">#</a></dt> <dd><p>A small float added to variance to avoid dividing by zero.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>float</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.RMSNorm.dtype"> <span class="sig-name descname"><span class="pre">dtype</span></span><a class="headerlink" href="#flax.linen.RMSNorm.dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype of the result (default: infer from input and params).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.RMSNorm.param_dtype"> <span class="sig-name descname"><span class="pre">param_dtype</span></span><a class="headerlink" href="#flax.linen.RMSNorm.param_dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype passed to parameter initializers (default: float32).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.RMSNorm.use_scale"> <span class="sig-name descname"><span class="pre">use_scale</span></span><a class="headerlink" href="#flax.linen.RMSNorm.use_scale" title="Permalink to this definition">#</a></dt> <dd><p>If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.RMSNorm.scale_init"> <span class="sig-name descname"><span class="pre">scale_init</span></span><a class="headerlink" href="#flax.linen.RMSNorm.scale_init" title="Permalink to this definition">#</a></dt> <dd><p>Initializer for scale, by default, one.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.RMSNorm.reduction_axes"> <span class="sig-name descname"><span class="pre">reduction_axes</span></span><a class="headerlink" href="#flax.linen.RMSNorm.reduction_axes" title="Permalink to this definition">#</a></dt> <dd><p>Axes for computing normalization statistics.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[int, collections.abc.Sequence[int]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.RMSNorm.feature_axes"> <span class="sig-name descname"><span class="pre">feature_axes</span></span><a class="headerlink" href="#flax.linen.RMSNorm.feature_axes" title="Permalink to this definition">#</a></dt> <dd><p>Feature axes for learned bias and scaling.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[int, collections.abc.Sequence[int]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.RMSNorm.axis_name"> <span class="sig-name descname"><span class="pre">axis_name</span></span><a class="headerlink" href="#flax.linen.RMSNorm.axis_name" title="Permalink to this definition">#</a></dt> <dd><p>the axis name used to combine batch statistics from multiple devices. See <code class="docutils literal notranslate"><span class="pre">jax.pmap</span></code> for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap or shard map. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>str | None</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.RMSNorm.axis_index_groups"> <span class="sig-name descname"><span class="pre">axis_index_groups</span></span><a class="headerlink" href="#flax.linen.RMSNorm.axis_index_groups" title="Permalink to this definition">#</a></dt> <dd><p>groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, <code class="docutils literal notranslate"><span class="pre">[[0,</span> <span class="pre">1],</span> <span class="pre">[2,</span> <span class="pre">3]]</span></code> would independently batch-normalize over the examples on the first two and last two devices. See <code class="docutils literal notranslate"><span class="pre">jax.lax.psum</span></code> for more details.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Any</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.RMSNorm.use_fast_variance"> <span class="sig-name descname"><span class="pre">use_fast_variance</span></span><a class="headerlink" href="#flax.linen.RMSNorm.use_fast_variance" title="Permalink to this definition">#</a></dt> <dd><p>If true, use a faster, but less numerically stable, calculation for the variance.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.RMSNorm.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/normalization.html#RMSNorm.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.RMSNorm.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Applies RMS layer normalization on the input.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>x</strong> – the inputs</p></li> <li><p><strong>mask</strong> – Binary array of shape broadcastable to <code class="docutils literal notranslate"><span class="pre">inputs</span></code> tensor, indicating the positions for which the mean and variance should be computed.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>Normalized inputs (the same shape as inputs).</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> </tbody> </table> </div> </dd></dl> </div> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.InstanceNorm"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">InstanceNorm</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">epsilon=1e-06</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_bias=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_scale=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bias_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scale_init=<function</span> <span class="pre">ones></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">feature_axes=-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">axis_name=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">axis_index_groups=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_fast_variance=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">force_float32_reductions=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/normalization.html#InstanceNorm"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.InstanceNorm" title="Permalink to this definition">#</a></dt> <dd><p>Instance normalization (<a class="reference external" href="https://arxiv.org/abs/1607.08022v3">https://arxiv.org/abs/1607.08022v3</a>).</p> <p>InstanceNorm normalizes the activations of the layer for each channel (rather than across all channels like Layer Normalization), and for each given example in a batch independently (rather than across an entire batch like Batch Normalization). i.e. applies a transformation that maintains the mean activation within each channel within each example close to 0 and the activation standard deviation close to 1.</p> <div class="admonition note"> <p class="admonition-title">Note</p> <p>This normalization operation is identical to LayerNorm and GroupNorm; the difference is simply which axes are reduced and the shape of the feature axes (i.e. the shape of the learnable scale and bias parameters).</p> </div> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="gp">>>> </span><span class="c1"># dimensions: (batch, height, width, channel)</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">InstanceNorm</span><span class="p">()</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="go">{'params': {'scale': Array([1., 1., 1., 1., 1.], dtype=float32), 'bias': Array([0., 0., 0., 0., 0.], dtype=float32)}}</span> <span class="gp">>>> </span><span class="n">y</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="c1"># having a channel_axis of -1 in InstanceNorm is identical to reducing all non-batch,</span> <span class="gp">>>> </span><span class="c1"># non-channel axes and using the feature_axes as the feature_axes in LayerNorm</span> <span class="gp">>>> </span><span class="n">y2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LayerNorm</span><span class="p">(</span><span class="n">reduction_axes</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">],</span> <span class="n">feature_axes</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">y2</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="mf">1e-7</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">y3</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GroupNorm</span><span class="p">(</span><span class="n">num_groups</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">y3</span><span class="p">,</span> <span class="n">atol</span><span class="o">=</span><span class="mf">1e-7</span><span class="p">)</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.InstanceNorm.epsilon"> <span class="sig-name descname"><span class="pre">epsilon</span></span><a class="headerlink" href="#flax.linen.InstanceNorm.epsilon" title="Permalink to this definition">#</a></dt> <dd><p>A small float added to variance to avoid dividing by zero.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>float</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.InstanceNorm.dtype"> <span class="sig-name descname"><span class="pre">dtype</span></span><a class="headerlink" href="#flax.linen.InstanceNorm.dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype of the result (default: infer from input and params).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.InstanceNorm.param_dtype"> <span class="sig-name descname"><span class="pre">param_dtype</span></span><a class="headerlink" href="#flax.linen.InstanceNorm.param_dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype passed to parameter initializers (default: float32).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.InstanceNorm.use_bias"> <span class="sig-name descname"><span class="pre">use_bias</span></span><a class="headerlink" href="#flax.linen.InstanceNorm.use_bias" title="Permalink to this definition">#</a></dt> <dd><p>If True, bias (beta) is added.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.InstanceNorm.use_scale"> <span class="sig-name descname"><span class="pre">use_scale</span></span><a class="headerlink" href="#flax.linen.InstanceNorm.use_scale" title="Permalink to this definition">#</a></dt> <dd><p>If True, multiply by scale (gamma). When the next layer is linear (also e.g. nn.relu), this can be disabled since the scaling will be done by the next layer.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.InstanceNorm.bias_init"> <span class="sig-name descname"><span class="pre">bias_init</span></span><a class="headerlink" href="#flax.linen.InstanceNorm.bias_init" title="Permalink to this definition">#</a></dt> <dd><p>Initializer for bias, by default, zero.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.InstanceNorm.scale_init"> <span class="sig-name descname"><span class="pre">scale_init</span></span><a class="headerlink" href="#flax.linen.InstanceNorm.scale_init" title="Permalink to this definition">#</a></dt> <dd><p>Initializer for scale, by default, one.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.InstanceNorm.feature_axes"> <span class="sig-name descname"><span class="pre">feature_axes</span></span><a class="headerlink" href="#flax.linen.InstanceNorm.feature_axes" title="Permalink to this definition">#</a></dt> <dd><p>Axes for features. The learned bias and scaling parameters will be in the shape defined by the feature axes. All other axes except the batch axes (which is assumed to be the leading axis) will be reduced.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[int, collections.abc.Sequence[int]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.InstanceNorm.axis_name"> <span class="sig-name descname"><span class="pre">axis_name</span></span><a class="headerlink" href="#flax.linen.InstanceNorm.axis_name" title="Permalink to this definition">#</a></dt> <dd><p>the axis name used to combine batch statistics from multiple devices. See <code class="docutils literal notranslate"><span class="pre">jax.pmap</span></code> for a description of axis names (default: None). This is only needed if the model is subdivided across devices, i.e. the array being normalized is sharded across devices within a pmap or shard map. For SPMD jit, you do not need to manually synchronize. Just make sure that the axes are correctly annotated and XLA:SPMD will insert the necessary collectives.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>str | None</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.InstanceNorm.axis_index_groups"> <span class="sig-name descname"><span class="pre">axis_index_groups</span></span><a class="headerlink" href="#flax.linen.InstanceNorm.axis_index_groups" title="Permalink to this definition">#</a></dt> <dd><p>groups of axis indices within that named axis representing subsets of devices to reduce over (default: None). For example, <code class="docutils literal notranslate"><span class="pre">[[0,</span> <span class="pre">1],</span> <span class="pre">[2,</span> <span class="pre">3]]</span></code> would independently batch-normalize over the examples on the first two and last two devices. See <code class="docutils literal notranslate"><span class="pre">jax.lax.psum</span></code> for more details.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Any</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.InstanceNorm.use_fast_variance"> <span class="sig-name descname"><span class="pre">use_fast_variance</span></span><a class="headerlink" href="#flax.linen.InstanceNorm.use_fast_variance" title="Permalink to this definition">#</a></dt> <dd><p>If true, use a faster, but less numerically stable, calculation for the variance.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.InstanceNorm.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/normalization.html#InstanceNorm.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.InstanceNorm.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Applies instance normalization on the input.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>x</strong> – the inputs</p></li> <li><p><strong>mask</strong> – Binary array of shape broadcastable to <code class="docutils literal notranslate"><span class="pre">inputs</span></code> tensor, indicating the positions for which the mean and variance should be computed.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>Normalized inputs (the same shape as inputs).</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> </tbody> </table> </div> </dd></dl> </div> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.SpectralNorm"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">SpectralNorm</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">layer_instance</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">n_steps=1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">epsilon=1e-12</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">error_on_non_matrix=False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">collection_name='batch_stats'</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/normalization.html#SpectralNorm"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.SpectralNorm" title="Permalink to this definition">#</a></dt> <dd><p>Spectral normalization.</p> <p>See:</p> <ul class="simple"> <li><p><a class="reference external" href="https://arxiv.org/abs/1802.05957">https://arxiv.org/abs/1802.05957</a></p></li> <li><p><a class="reference external" href="https://arxiv.org/abs/1805.08318">https://arxiv.org/abs/1805.08318</a></p></li> <li><p><a class="reference external" href="https://arxiv.org/abs/1809.11096">https://arxiv.org/abs/1809.11096</a></p></li> </ul> <p>Spectral normalization normalizes the weight params so that the spectral norm of the matrix is equal to 1. This is implemented as a layer wrapper where each wrapped layer will have its params spectral normalized before computing its <code class="docutils literal notranslate"><span class="pre">__call__</span></code> output.</p> <div class="admonition note"> <p class="admonition-title">Note</p> <p>The initialized variables dict will contain, in addition to a ‘params’ collection, a separate ‘batch_stats’ collection that will contain a <code class="docutils literal notranslate"><span class="pre">u</span></code> vector and <code class="docutils literal notranslate"><span class="pre">sigma</span></code> value, which are intermediate values used when performing spectral normalization. During training, we pass in <code class="docutils literal notranslate"><span class="pre">update_stats=True</span></code> and <code class="docutils literal notranslate"><span class="pre">mutable=['batch_stats']</span></code> so that <code class="docutils literal notranslate"><span class="pre">u</span></code> and <code class="docutils literal notranslate"><span class="pre">sigma</span></code> are updated with the most recently computed values using power iteration. This will help the power iteration method approximate the true singular value more accurately over time. During eval, we pass in <code class="docutils literal notranslate"><span class="pre">update_stats=False</span></code> to ensure we get deterministic behavior from the model.</p> </div> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax</span><span class="o">,</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">optax</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">Foo</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">train</span><span class="p">):</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">3</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... </span> <span class="c1"># only spectral normalize the params of the second Dense layer</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">SpectralNorm</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">4</span><span class="p">))(</span><span class="n">x</span><span class="p">,</span> <span class="n">update_stats</span><span class="o">=</span><span class="n">train</span><span class="p">)</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">5</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">x</span> <span class="gp">>>> </span><span class="c1"># init</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">y</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">model</span> <span class="o">=</span> <span class="n">Foo</span><span class="p">()</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">PRNGKey</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">flax</span><span class="o">.</span><span class="n">core</span><span class="o">.</span><span class="n">freeze</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">variables</span><span class="p">))</span> <span class="go">FrozenDict({</span> <span class="go"> batch_stats: {</span> <span class="go"> SpectralNorm_0: {</span> <span class="go"> Dense_1/kernel/sigma: (),</span> <span class="go"> Dense_1/kernel/u: (1, 4),</span> <span class="go"> },</span> <span class="go"> },</span> <span class="go"> params: {</span> <span class="go"> Dense_0: {</span> <span class="go"> bias: (3,),</span> <span class="go"> kernel: (2, 3),</span> <span class="go"> },</span> <span class="go"> Dense_1: {</span> <span class="go"> bias: (4,),</span> <span class="go"> kernel: (3, 4),</span> <span class="go"> },</span> <span class="go"> Dense_2: {</span> <span class="go"> bias: (5,),</span> <span class="go"> kernel: (4, 5),</span> <span class="go"> },</span> <span class="go"> },</span> <span class="go">})</span> <span class="gp">>>> </span><span class="c1"># train</span> <span class="gp">>>> </span><span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">):</span> <span class="gp">... </span> <span class="k">def</span> <span class="nf">loss_fn</span><span class="p">(</span><span class="n">params</span><span class="p">):</span> <span class="gp">... </span> <span class="n">logits</span><span class="p">,</span> <span class="n">updates</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span> <span class="gp">... </span> <span class="p">{</span><span class="s1">'params'</span><span class="p">:</span> <span class="n">params</span><span class="p">,</span> <span class="s1">'batch_stats'</span><span class="p">:</span> <span class="n">variables</span><span class="p">[</span><span class="s1">'batch_stats'</span><span class="p">]},</span> <span class="gp">... </span> <span class="n">x</span><span class="p">,</span> <span class="gp">... </span> <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="gp">... </span> <span class="n">mutable</span><span class="o">=</span><span class="p">[</span><span class="s1">'batch_stats'</span><span class="p">],</span> <span class="gp">... </span> <span class="p">)</span> <span class="gp">... </span> <span class="n">loss</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">optax</span><span class="o">.</span><span class="n">l2_loss</span><span class="p">(</span><span class="n">predictions</span><span class="o">=</span><span class="n">logits</span><span class="p">,</span> <span class="n">targets</span><span class="o">=</span><span class="n">y</span><span class="p">))</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">updates</span> <span class="gp">...</span> <span class="gp">... </span> <span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">updates</span><span class="p">),</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">value_and_grad</span><span class="p">(</span><span class="n">loss_fn</span><span class="p">,</span> <span class="n">has_aux</span><span class="o">=</span><span class="kc">True</span><span class="p">)(</span> <span class="gp">... </span> <span class="n">variables</span><span class="p">[</span><span class="s1">'params'</span><span class="p">]</span> <span class="gp">... </span> <span class="p">)</span> <span class="gp">... </span> <span class="k">return</span> <span class="p">{</span> <span class="gp">... </span> <span class="s1">'params'</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span> <span class="gp">... </span> <span class="k">lambda</span> <span class="n">p</span><span class="p">,</span> <span class="n">g</span><span class="p">:</span> <span class="n">p</span> <span class="o">-</span> <span class="mf">0.1</span> <span class="o">*</span> <span class="n">g</span><span class="p">,</span> <span class="n">variables</span><span class="p">[</span><span class="s1">'params'</span><span class="p">],</span> <span class="n">grads</span> <span class="gp">... </span> <span class="p">),</span> <span class="gp">... </span> <span class="s1">'batch_stats'</span><span class="p">:</span> <span class="n">updates</span><span class="p">[</span><span class="s1">'batch_stats'</span><span class="p">],</span> <span class="gp">... </span> <span class="p">},</span> <span class="n">loss</span> <span class="gp">>>> </span><span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">10</span><span class="p">):</span> <span class="gp">... </span> <span class="n">variables</span><span class="p">,</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">train_step</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> <span class="gp">>>> </span><span class="c1"># inference / eval</span> <span class="gp">>>> </span><span class="n">out</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.SpectralNorm.layer_instance"> <span class="sig-name descname"><span class="pre">layer_instance</span></span><a class="headerlink" href="#flax.linen.SpectralNorm.layer_instance" title="Permalink to this definition">#</a></dt> <dd><p>Module instance that is wrapped with SpectralNorm</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p><a class="reference internal" href="module.html#flax.linen.Module" title="flax.linen.module.Module">flax.linen.module.Module</a></p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.SpectralNorm.n_steps"> <span class="sig-name descname"><span class="pre">n_steps</span></span><a class="headerlink" href="#flax.linen.SpectralNorm.n_steps" title="Permalink to this definition">#</a></dt> <dd><p>How many steps of power iteration to perform to approximate the singular value of the weight params.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.SpectralNorm.epsilon"> <span class="sig-name descname"><span class="pre">epsilon</span></span><a class="headerlink" href="#flax.linen.SpectralNorm.epsilon" title="Permalink to this definition">#</a></dt> <dd><p>A small float added to l2-normalization to avoid dividing by zero.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>float</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.SpectralNorm.dtype"> <span class="sig-name descname"><span class="pre">dtype</span></span><a class="headerlink" href="#flax.linen.SpectralNorm.dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype of the result (default: infer from input and params).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.SpectralNorm.param_dtype"> <span class="sig-name descname"><span class="pre">param_dtype</span></span><a class="headerlink" href="#flax.linen.SpectralNorm.param_dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype passed to parameter initializers (default: float32).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.SpectralNorm.error_on_non_matrix"> <span class="sig-name descname"><span class="pre">error_on_non_matrix</span></span><a class="headerlink" href="#flax.linen.SpectralNorm.error_on_non_matrix" title="Permalink to this definition">#</a></dt> <dd><p>Spectral normalization is only defined on matrices. By default, this module will return scalars unchanged and flatten higher-order tensors in their leading dimensions. Setting this flag to True will instead throw an error if a weight tensor with dimension greater than 2 is used by the layer.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.SpectralNorm.collection_name"> <span class="sig-name descname"><span class="pre">collection_name</span></span><a class="headerlink" href="#flax.linen.SpectralNorm.collection_name" title="Permalink to this definition">#</a></dt> <dd><p>Name of the collection to store intermediate values used when performing spectral normalization.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>str</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.SpectralNorm.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">update_stats</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/normalization.html#SpectralNorm.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.SpectralNorm.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Compute the largest singular value of the weights in <code class="docutils literal notranslate"><span class="pre">self.layer_instance</span></code> using power iteration and normalize the weights using this value before computing the <code class="docutils literal notranslate"><span class="pre">__call__</span></code> output.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>*args</strong> – positional arguments to be passed into the call method of the underlying layer instance in <code class="docutils literal notranslate"><span class="pre">self.layer_instance</span></code>.</p></li> <li><p><strong>update_stats</strong> – if True, update the internal <code class="docutils literal notranslate"><span class="pre">u</span></code> vector and <code class="docutils literal notranslate"><span class="pre">sigma</span></code> value after computing their updated values using power iteration. This will help the power iteration method approximate the true singular value more accurately over time.</p></li> <li><p><strong>**kwargs</strong> – keyword arguments to be passed into the call method of the underlying layer instance in <code class="docutils literal notranslate"><span class="pre">self.layer_instance</span></code>.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>Output of the layer using spectral normalized weights.</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> </tbody> </table> </div> </dd></dl> </div> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.WeightNorm"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">WeightNorm</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">layer_instance</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">epsilon=1e-12</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_scale=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">scale_init=<function</span> <span class="pre">ones></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">feature_axes=-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">variable_filter=<factory></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/normalization.html#WeightNorm"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.WeightNorm" title="Permalink to this definition">#</a></dt> <dd><p>L2 weight normalization (<a class="reference external" href="https://arxiv.org/abs/1602.07868">https://arxiv.org/abs/1602.07868</a>).</p> <p>Weight normalization normalizes the weight params so that the l2-norm of the matrix is equal to 1. This is implemented as a layer wrapper where each wrapped layer will have its params l2-normalized before computing its <code class="docutils literal notranslate"><span class="pre">__call__</span></code> output.</p> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax</span><span class="o">,</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">Baz</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">2</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">Bar</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">Baz</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">3</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">Baz</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">3</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">x</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">Foo</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">3</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... </span> <span class="c1"># l2-normalize all params of the second Dense layer</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">WeightNorm</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">4</span><span class="p">),</span> <span class="n">variable_filter</span><span class="o">=</span><span class="kc">None</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">5</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... </span> <span class="c1"># l2-normalize all kernels in the Bar submodule and all params in</span> <span class="gp">... </span> <span class="c1"># the Baz submodule</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">WeightNorm</span><span class="p">(</span><span class="n">Bar</span><span class="p">(),</span> <span class="n">variable_filter</span><span class="o">=</span><span class="p">{</span><span class="s1">'kernel'</span><span class="p">,</span> <span class="s1">'Baz'</span><span class="p">})(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">x</span> <span class="gp">>>> </span><span class="c1"># init</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">model</span> <span class="o">=</span> <span class="n">Foo</span><span class="p">()</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">flax</span><span class="o">.</span><span class="n">core</span><span class="o">.</span><span class="n">freeze</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">variables</span><span class="p">))</span> <span class="go">FrozenDict({</span> <span class="go"> params: {</span> <span class="go"> Bar_0: {</span> <span class="go"> Baz_0: {</span> <span class="go"> Dense_0: {</span> <span class="go"> bias: (2,),</span> <span class="go"> kernel: (5, 2),</span> <span class="go"> },</span> <span class="go"> },</span> <span class="go"> Baz_1: {</span> <span class="go"> Dense_0: {</span> <span class="go"> bias: (2,),</span> <span class="go"> kernel: (3, 2),</span> <span class="go"> },</span> <span class="go"> },</span> <span class="go"> Dense_0: {</span> <span class="go"> bias: (3,),</span> <span class="go"> kernel: (2, 3),</span> <span class="go"> },</span> <span class="go"> Dense_1: {</span> <span class="go"> bias: (3,),</span> <span class="go"> kernel: (2, 3),</span> <span class="go"> },</span> <span class="go"> },</span> <span class="go"> Dense_0: {</span> <span class="go"> bias: (3,),</span> <span class="go"> kernel: (2, 3),</span> <span class="go"> },</span> <span class="go"> Dense_1: {</span> <span class="go"> bias: (4,),</span> <span class="go"> kernel: (3, 4),</span> <span class="go"> },</span> <span class="go"> Dense_2: {</span> <span class="go"> bias: (5,),</span> <span class="go"> kernel: (4, 5),</span> <span class="go"> },</span> <span class="go"> WeightNorm_0: {</span> <span class="go"> Dense_1/bias/scale: (4,),</span> <span class="go"> Dense_1/kernel/scale: (4,),</span> <span class="go"> },</span> <span class="go"> WeightNorm_1: {</span> <span class="go"> Bar_0/Baz_0/Dense_0/bias/scale: (2,),</span> <span class="go"> Bar_0/Baz_0/Dense_0/kernel/scale: (2,),</span> <span class="go"> Bar_0/Baz_1/Dense_0/bias/scale: (2,),</span> <span class="go"> Bar_0/Baz_1/Dense_0/kernel/scale: (2,),</span> <span class="go"> Bar_0/Dense_0/kernel/scale: (3,),</span> <span class="go"> Bar_0/Dense_1/kernel/scale: (3,),</span> <span class="go"> },</span> <span class="go"> },</span> <span class="go">})</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.WeightNorm.layer_instance"> <span class="sig-name descname"><span class="pre">layer_instance</span></span><a class="headerlink" href="#flax.linen.WeightNorm.layer_instance" title="Permalink to this definition">#</a></dt> <dd><p>Module instance that is wrapped with WeightNorm</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p><a class="reference internal" href="module.html#flax.linen.Module" title="flax.linen.module.Module">flax.linen.module.Module</a></p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.WeightNorm.epsilon"> <span class="sig-name descname"><span class="pre">epsilon</span></span><a class="headerlink" href="#flax.linen.WeightNorm.epsilon" title="Permalink to this definition">#</a></dt> <dd><p>A small float added to l2-normalization to avoid dividing by zero.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>float</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.WeightNorm.dtype"> <span class="sig-name descname"><span class="pre">dtype</span></span><a class="headerlink" href="#flax.linen.WeightNorm.dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype of the result (default: infer from input and params).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.WeightNorm.param_dtype"> <span class="sig-name descname"><span class="pre">param_dtype</span></span><a class="headerlink" href="#flax.linen.WeightNorm.param_dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype passed to parameter initializers (default: float32).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.WeightNorm.use_scale"> <span class="sig-name descname"><span class="pre">use_scale</span></span><a class="headerlink" href="#flax.linen.WeightNorm.use_scale" title="Permalink to this definition">#</a></dt> <dd><p>If True, creates a learnable variable <code class="docutils literal notranslate"><span class="pre">scale</span></code> that is multiplied to the <code class="docutils literal notranslate"><span class="pre">layer_instance</span></code> variables after l2-normalization.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.WeightNorm.scale_init"> <span class="sig-name descname"><span class="pre">scale_init</span></span><a class="headerlink" href="#flax.linen.WeightNorm.scale_init" title="Permalink to this definition">#</a></dt> <dd><p>Initialization function for the scaling function.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.WeightNorm.feature_axes"> <span class="sig-name descname"><span class="pre">feature_axes</span></span><a class="headerlink" href="#flax.linen.WeightNorm.feature_axes" title="Permalink to this definition">#</a></dt> <dd><p>The feature axes dimension(s). The l2-norm is calculated by reducing the <code class="docutils literal notranslate"><span class="pre">layer_instance</span></code> variables over the remaining (non-feature) axes. Therefore a separate l2-norm value is calculated and a separate scale (if <code class="docutils literal notranslate"><span class="pre">use_scale=True</span></code>) is learned for each specified feature. By default, the trailing dimension is treated as the feature axis.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[int, collections.abc.Sequence[int]]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.WeightNorm.variable_filter"> <span class="sig-name descname"><span class="pre">variable_filter</span></span><a class="headerlink" href="#flax.linen.WeightNorm.variable_filter" title="Permalink to this definition">#</a></dt> <dd><p>An optional iterable that contains string items. The WeightNorm layer will selectively apply l2-normalization to the <code class="docutils literal notranslate"><span class="pre">layer_instance</span></code> variables whose key path (delimited by ‘/’) has a match with <code class="docutils literal notranslate"><span class="pre">variable_filter</span></code>. For example, <code class="docutils literal notranslate"><span class="pre">variable_filter={'kernel'}</span></code> will only apply l2-normalization to variables whose key path contains ‘kernel’. By default, <code class="docutils literal notranslate"><span class="pre">variable_filter={'kernel'}</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Iterable | None</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.WeightNorm.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/normalization.html#WeightNorm.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.WeightNorm.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Compute the l2-norm of the weights in <code class="docutils literal notranslate"><span class="pre">self.layer_instance</span></code> and normalize the weights using this value before computing the <code class="docutils literal notranslate"><span class="pre">__call__</span></code> output.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>*args</strong> – positional arguments to be passed into the call method of the underlying layer instance in <code class="docutils literal notranslate"><span class="pre">self.layer_instance</span></code>.</p></li> <li><p><strong>**kwargs</strong> – keyword arguments to be passed into the call method of the underlying layer instance in <code class="docutils literal notranslate"><span class="pre">self.layer_instance</span></code>.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>Output of the layer using l2-normalized weights.</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> </tbody> </table> </div> </dd></dl> </div> </div> <div class="section" id="combinators"> <h2>Combinators<a class="headerlink" href="#combinators" title="Permalink to this heading">#</a></h2> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.Sequential"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">Sequential</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">layers</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/combinators.html#Sequential"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Sequential" title="Permalink to this definition">#</a></dt> <dd><p>Applies a linear chain of Modules.</p> <p>Meant to be used only for the simple case of fusing together callables where the input of a particular module/op is the output of the previous one.</p> <p>Modules will be applied in the order that they are passed in the constructor.</p> <p>The <code class="docutils literal notranslate"><span class="pre">__call__</span></code> method of Sequential accepts any input and forwards it to the first module it contains. It chains the output sequentially to the input of the next module and returns the output of the final module.</p> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">Foo</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span><span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">4</span><span class="p">),</span> <span class="gp">... </span> <span class="n">nn</span><span class="o">.</span><span class="n">relu</span><span class="p">,</span> <span class="gp">... </span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="gp">... </span> <span class="n">nn</span><span class="o">.</span><span class="n">log_softmax</span><span class="p">])(</span><span class="n">x</span><span class="p">)</span> </pre></div> </div> <p>Since <cite>Sequential.__call__</cite> is a <cite>compact</cite> method, you can also pass functions that construct Modules inline if you need shape inference:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="n">module</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span> <span class="c1"># << more layers</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">SomeModule</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])(</span><span class="n">x</span><span class="p">),</span> <span class="c1"># shape inference</span> <span class="c1"># << more layers</span> <span class="p">])</span> </pre></div> </div> <p>This combinator supports also layers that return multiple outputs if returned as a tuple or a dictionary. If the output of a layer is a <code class="docutils literal notranslate"><span class="pre">tuple</span></code> it will be expanded as <code class="docutils literal notranslate"><span class="pre">*args</span></code> in the next layer, if its a <code class="docutils literal notranslate"><span class="pre">dict</span></code> it will be expanded as <code class="docutils literal notranslate"><span class="pre">**kwargs</span></code>.</p> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="k">class</span> <span class="nc">CrossAttentionBlock</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="n">num_heads</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">2</span> <span class="gp">... </span> <span class="n">qkv_features</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">16</span> <span class="gp">...</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">key_value</span><span class="p">):</span> <span class="gp">... </span> <span class="n">output</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MultiHeadDotProductAttention</span><span class="p">(</span> <span class="gp">... </span> <span class="n">num_heads</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">qkv_features</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">qkv_features</span><span class="p">)(</span><span class="n">query</span><span class="p">,</span> <span class="gp">... </span> <span class="n">key_value</span><span class="p">)</span> <span class="gp">... </span> <span class="n">output</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">qkv_features</span><span class="p">)(</span><span class="n">output</span><span class="p">)</span> <span class="gp">... </span> <span class="k">return</span> <span class="nb">dict</span><span class="p">(</span><span class="n">query</span><span class="o">=</span><span class="n">output</span><span class="p">,</span> <span class="n">key_value</span><span class="o">=</span><span class="n">key_value</span><span class="p">)</span> <span class="c1"># also works for tuples</span> <span class="gp">>>> </span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Sequence</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">CrossAttentionNetwork</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="n">num_layers</span><span class="p">:</span> <span class="n">Sequence</span><span class="p">[</span><span class="nb">int</span><span class="p">]</span> <span class="gp">...</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span><span class="n">CrossAttentionBlock</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="gp">... </span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">)])(</span><span class="n">query</span><span class="p">,</span> <span class="n">key_value</span><span class="p">)</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Sequential.layers"> <span class="sig-name descname"><span class="pre">layers</span></span><a class="headerlink" href="#flax.linen.Sequential.layers" title="Permalink to this definition">#</a></dt> <dd><p>A sequence of callables to be applied in order.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Sequence[collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="field-list simple"> <dt class="field-odd">Raises</dt> <dd class="field-odd"><p><strong>ValueError</strong> – If layers is not a sequence.</p> </dd> </dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Sequential.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/combinators.html#Sequential.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Sequential.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Call self as a function.</p> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> </tbody> </table> </div> </dd></dl> </div> </div> <div class="section" id="stochastic"> <h2>Stochastic<a class="headerlink" href="#stochastic" title="Permalink to this heading">#</a></h2> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.Dropout"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">Dropout</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">rate</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">broadcast_dims=()</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">deterministic=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">rng_collection='dropout'</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/stochastic.html#Dropout"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Dropout" title="Permalink to this definition">#</a></dt> <dd><p>Create a dropout layer.</p> <div class="admonition note"> <p class="admonition-title">Note</p> <p>When using <a class="reference internal" href="module.html#flax.linen.Module.apply" title="flax.linen.Module.apply"><code class="xref py py-meth docutils literal notranslate"><span class="pre">Module.apply()</span></code></a>, make sure to include an RNG seed named <code class="docutils literal notranslate"><span class="pre">'dropout'</span></code>. Dropout isn’t necessary for variable initialization.</p> </div> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">MLP</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">train</span><span class="p">):</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">4</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">deterministic</span><span class="o">=</span><span class="ow">not</span> <span class="n">train</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">x</span> <span class="gp">>>> </span><span class="n">model</span> <span class="o">=</span> <span class="n">MLP</span><span class="p">()</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="c1"># don't use dropout</span> <span class="gp">>>> </span><span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="c1"># don't use dropout</span> <span class="go">Array([[-0.88686204, -0.5928178 , -0.5184689 , -0.4345976 ]], dtype=float32)</span> <span class="gp">>>> </span><span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">rngs</span><span class="o">=</span><span class="p">{</span><span class="s1">'dropout'</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">1</span><span class="p">)})</span> <span class="c1"># use dropout</span> <span class="go">Array([[ 0. , -1.1856356, -1.0369378, 0. ]], dtype=float32)</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Dropout.rate"> <span class="sig-name descname"><span class="pre">rate</span></span><a class="headerlink" href="#flax.linen.Dropout.rate" title="Permalink to this definition">#</a></dt> <dd><p>the dropout probability. (_not_ the keep rate!)</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>float</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Dropout.broadcast_dims"> <span class="sig-name descname"><span class="pre">broadcast_dims</span></span><a class="headerlink" href="#flax.linen.Dropout.broadcast_dims" title="Permalink to this definition">#</a></dt> <dd><p>dimensions that will share the same dropout mask</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Sequence[int]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Dropout.deterministic"> <span class="sig-name descname"><span class="pre">deterministic</span></span><a class="headerlink" href="#flax.linen.Dropout.deterministic" title="Permalink to this definition">#</a></dt> <dd><p>if false the inputs are scaled by <code class="docutils literal notranslate"><span class="pre">1</span> <span class="pre">/</span> <span class="pre">(1</span> <span class="pre">-</span> <span class="pre">rate)</span></code> and masked, whereas if true, no mask is applied and the inputs are returned as is.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool | None</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.Dropout.rng_collection"> <span class="sig-name descname"><span class="pre">rng_collection</span></span><a class="headerlink" href="#flax.linen.Dropout.rng_collection" title="Permalink to this definition">#</a></dt> <dd><p>the rng collection name to use when requesting an rng key.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>str</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Dropout.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">inputs</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">deterministic</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">rng</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/stochastic.html#Dropout.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Dropout.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Applies a random dropout mask to the input.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>inputs</strong> – the inputs that should be randomly masked.</p></li> <li><p><strong>deterministic</strong> – if false the inputs are scaled by <code class="docutils literal notranslate"><span class="pre">1</span> <span class="pre">/</span> <span class="pre">(1</span> <span class="pre">-</span> <span class="pre">rate)</span></code> and masked, whereas if true, no mask is applied and the inputs are returned as is.</p></li> <li><p><strong>rng</strong> – an optional PRNGKey used as the random key, if not specified, one will be generated using <code class="docutils literal notranslate"><span class="pre">make_rng</span></code> with the <code class="docutils literal notranslate"><span class="pre">rng_collection</span></code> name.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>The masked inputs reweighted to preserve mean.</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> </tbody> </table> </div> </dd></dl> </div> </div> <div class="section" id="attention"> <h2>Attention<a class="headerlink" href="#attention" title="Permalink to this heading">#</a></h2> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.MultiHeadDotProductAttention"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">MultiHeadDotProductAttention</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">num_heads</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">qkv_features=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_features=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">broadcast_dropout=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dropout_rate=0.0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">deterministic=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">precision=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kernel_init=<function</span> <span class="pre">variance_scaling.<locals>.init></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_kernel_init=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bias_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_bias_init=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_bias=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">attention_fn=<function</span> <span class="pre">dot_product_attention></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">decode=False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">normalize_qk=False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">force_fp32_for_softmax=False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">qkv_dot_general=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_dot_general=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">qkv_dot_general_cls=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_dot_general_cls=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">qk_attn_weights_einsum_cls=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">attn_weights_value_einsum_cls=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/attention.html#MultiHeadDotProductAttention"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.MultiHeadDotProductAttention" title="Permalink to this definition">#</a></dt> <dd><p>Multi-head dot-product attention.</p> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MultiHeadDotProductAttention</span><span class="p">(</span><span class="n">num_heads</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">qkv_features</span><span class="o">=</span><span class="mi">16</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">key1</span><span class="p">,</span> <span class="n">key2</span><span class="p">,</span> <span class="n">key3</span><span class="p">,</span> <span class="n">key4</span><span class="p">,</span> <span class="n">key5</span><span class="p">,</span> <span class="n">key6</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="mi">6</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">shape</span> <span class="o">=</span> <span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">key1</span><span class="p">,</span> <span class="n">shape</span><span class="p">),</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">key2</span><span class="p">,</span> <span class="n">shape</span><span class="p">),</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">key3</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">q</span><span class="p">)</span> <span class="gp">>>> </span><span class="c1"># different inputs for inputs_q, inputs_k and inputs_v</span> <span class="gp">>>> </span><span class="n">out</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span> <span class="gp">>>> </span><span class="c1"># equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k)</span> <span class="gp">>>> </span><span class="n">out</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="gp">>>> </span><span class="c1"># equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q)</span> <span class="gp">>>> </span><span class="n">out</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">q</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">attention_kwargs</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span> <span class="gp">... </span> <span class="n">num_heads</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="gp">... </span> <span class="n">qkv_features</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="gp">... </span> <span class="n">kernel_init</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">ones</span><span class="p">,</span> <span class="gp">... </span> <span class="n">bias_init</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">zeros</span><span class="p">,</span> <span class="gp">... </span> <span class="n">dropout_rate</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="gp">... </span> <span class="n">deterministic</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="gp">... </span> <span class="p">)</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">Module</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="n">attention_kwargs</span><span class="p">:</span> <span class="nb">dict</span> <span class="gp">...</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">dropout_rng</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="gp">... </span> <span class="n">out1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MultiHeadDotProductAttention</span><span class="p">(</span><span class="o">**</span><span class="bp">self</span><span class="o">.</span><span class="n">attention_kwargs</span><span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span class="n">dropout_rng</span><span class="o">=</span><span class="n">dropout_rng</span><span class="p">)</span> <span class="gp">... </span> <span class="n">out2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MultiHeadDotProductAttention</span><span class="p">(</span><span class="o">**</span><span class="bp">self</span><span class="o">.</span><span class="n">attention_kwargs</span><span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span class="n">dropout_rng</span><span class="o">=</span><span class="n">dropout_rng</span><span class="p">)</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">out1</span><span class="p">,</span> <span class="n">out2</span> <span class="gp">>>> </span><span class="n">module</span> <span class="o">=</span> <span class="n">Module</span><span class="p">(</span><span class="n">attention_kwargs</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">init</span><span class="p">({</span><span class="s1">'params'</span><span class="p">:</span> <span class="n">key1</span><span class="p">,</span> <span class="s1">'dropout'</span><span class="p">:</span> <span class="n">key2</span><span class="p">},</span> <span class="n">q</span><span class="p">)</span> <span class="gp">>>> </span><span class="c1"># out1 and out2 are different.</span> <span class="gp">>>> </span><span class="n">out1</span><span class="p">,</span> <span class="n">out2</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">rngs</span><span class="o">=</span><span class="p">{</span><span class="s1">'dropout'</span><span class="p">:</span> <span class="n">key3</span><span class="p">})</span> <span class="gp">>>> </span><span class="c1"># out3 and out4 are different.</span> <span class="gp">>>> </span><span class="c1"># out1 and out3 are different. out2 and out4 are different.</span> <span class="gp">>>> </span><span class="n">out3</span><span class="p">,</span> <span class="n">out4</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">rngs</span><span class="o">=</span><span class="p">{</span><span class="s1">'dropout'</span><span class="p">:</span> <span class="n">key4</span><span class="p">})</span> <span class="gp">>>> </span><span class="c1"># out1 and out2 are the same.</span> <span class="gp">>>> </span><span class="n">out1</span><span class="p">,</span> <span class="n">out2</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">dropout_rng</span><span class="o">=</span><span class="n">key5</span><span class="p">)</span> <span class="gp">>>> </span><span class="c1"># out1 and out2 are the same as out3 and out4.</span> <span class="gp">>>> </span><span class="c1"># providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply`</span> <span class="gp">>>> </span><span class="n">out3</span><span class="p">,</span> <span class="n">out4</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">rngs</span><span class="o">=</span><span class="p">{</span><span class="s1">'dropout'</span><span class="p">:</span> <span class="n">key6</span><span class="p">},</span> <span class="n">dropout_rng</span><span class="o">=</span><span class="n">key5</span><span class="p">)</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadDotProductAttention.num_heads"> <span class="sig-name descname"><span class="pre">num_heads</span></span><a class="headerlink" href="#flax.linen.MultiHeadDotProductAttention.num_heads" title="Permalink to this definition">#</a></dt> <dd><p>Number of attention heads. Features (i.e. inputs_q.shape[-1]) should be divisible by the number of heads.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadDotProductAttention.dtype"> <span class="sig-name descname"><span class="pre">dtype</span></span><a class="headerlink" href="#flax.linen.MultiHeadDotProductAttention.dtype" title="Permalink to this definition">#</a></dt> <dd><p>The dtype of the computation (default: infer from inputs and params)</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadDotProductAttention.param_dtype"> <span class="sig-name descname"><span class="pre">param_dtype</span></span><a class="headerlink" href="#flax.linen.MultiHeadDotProductAttention.param_dtype" title="Permalink to this definition">#</a></dt> <dd><p>The dtype passed to parameter initializers (default: float32)</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadDotProductAttention.qkv_features"> <span class="sig-name descname"><span class="pre">qkv_features</span></span><a class="headerlink" href="#flax.linen.MultiHeadDotProductAttention.qkv_features" title="Permalink to this definition">#</a></dt> <dd><p>Dimension of the key, query, and value.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int | None</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadDotProductAttention.out_features"> <span class="sig-name descname"><span class="pre">out_features</span></span><a class="headerlink" href="#flax.linen.MultiHeadDotProductAttention.out_features" title="Permalink to this definition">#</a></dt> <dd><p>Dimension of the last projection</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int | None</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadDotProductAttention.broadcast_dropout"> <span class="sig-name descname"><span class="pre">broadcast_dropout</span></span><a class="headerlink" href="#flax.linen.MultiHeadDotProductAttention.broadcast_dropout" title="Permalink to this definition">#</a></dt> <dd><p>Use a broadcasted dropout along batch dims.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadDotProductAttention.dropout_rate"> <span class="sig-name descname"><span class="pre">dropout_rate</span></span><a class="headerlink" href="#flax.linen.MultiHeadDotProductAttention.dropout_rate" title="Permalink to this definition">#</a></dt> <dd><p>Dropout rate.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>float</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadDotProductAttention.deterministic"> <span class="sig-name descname"><span class="pre">deterministic</span></span><a class="headerlink" href="#flax.linen.MultiHeadDotProductAttention.deterministic" title="Permalink to this definition">#</a></dt> <dd><p>If False, the attention weight is masked randomly using dropout, whereas if True, the attention weights are deterministic.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool | None</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadDotProductAttention.precision"> <span class="sig-name descname"><span class="pre">precision</span></span><a class="headerlink" href="#flax.linen.MultiHeadDotProductAttention.precision" title="Permalink to this definition">#</a></dt> <dd><p>Numerical precision of the computation see <code class="docutils literal notranslate"><span class="pre">jax.lax.Precision</span></code> for details.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadDotProductAttention.kernel_init"> <span class="sig-name descname"><span class="pre">kernel_init</span></span><a class="headerlink" href="#flax.linen.MultiHeadDotProductAttention.kernel_init" title="Permalink to this definition">#</a></dt> <dd><p>Initializer for the kernel of the Dense layers.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadDotProductAttention.out_kernel_init"> <span class="sig-name descname"><span class="pre">out_kernel_init</span></span><a class="headerlink" href="#flax.linen.MultiHeadDotProductAttention.out_kernel_init" title="Permalink to this definition">#</a></dt> <dd><p>Optional Initializer for the kernel of the output Dense layer, if None, <code class="docutils literal notranslate"><span class="pre">kernel_init</span></code> will be used.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadDotProductAttention.bias_init"> <span class="sig-name descname"><span class="pre">bias_init</span></span><a class="headerlink" href="#flax.linen.MultiHeadDotProductAttention.bias_init" title="Permalink to this definition">#</a></dt> <dd><p>Initializer for the bias of the Dense layers.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadDotProductAttention.out_bias_init"> <span class="sig-name descname"><span class="pre">out_bias_init</span></span><a class="headerlink" href="#flax.linen.MultiHeadDotProductAttention.out_bias_init" title="Permalink to this definition">#</a></dt> <dd><p>Optional Initializer for the bias of the output Dense layer, if None, <code class="docutils literal notranslate"><span class="pre">bias_init</span></code> will be used.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadDotProductAttention.use_bias"> <span class="sig-name descname"><span class="pre">use_bias</span></span><a class="headerlink" href="#flax.linen.MultiHeadDotProductAttention.use_bias" title="Permalink to this definition">#</a></dt> <dd><p>Whether pointwise QKVO dense transforms use bias.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadDotProductAttention.attention_fn"> <span class="sig-name descname"><span class="pre">attention_fn</span></span><a class="headerlink" href="#flax.linen.MultiHeadDotProductAttention.attention_fn" title="Permalink to this definition">#</a></dt> <dd><p>dot_product_attention or compatible function. Accepts query, key, value, and returns output of shape <code class="docutils literal notranslate"><span class="pre">[bs,</span> <span class="pre">dim1,</span> <span class="pre">dim2,</span> <span class="pre">...,</span> <span class="pre">dimN,,</span> <span class="pre">num_heads,</span> <span class="pre">value_channels]</span></code></p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Callable[[…], Union[jax.Array, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadDotProductAttention.decode"> <span class="sig-name descname"><span class="pre">decode</span></span><a class="headerlink" href="#flax.linen.MultiHeadDotProductAttention.decode" title="Permalink to this definition">#</a></dt> <dd><p>Whether to prepare and use an autoregressive cache.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadDotProductAttention.normalize_qk"> <span class="sig-name descname"><span class="pre">normalize_qk</span></span><a class="headerlink" href="#flax.linen.MultiHeadDotProductAttention.normalize_qk" title="Permalink to this definition">#</a></dt> <dd><p>Should QK normalization be applied (arxiv.org/abs/2302.05442).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadDotProductAttention.qk_attn_weights_einsum_cls"> <span class="sig-name descname"><span class="pre">qk_attn_weights_einsum_cls</span></span><a class="headerlink" href="#flax.linen.MultiHeadDotProductAttention.qk_attn_weights_einsum_cls" title="Permalink to this definition">#</a></dt> <dd><p>factory function to create the einsum for computing the attention weights.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Callable[[…], collections.abc.Callable[[…], Union[jax.Array, Any]]] | None</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadDotProductAttention.attn_weights_value_einsum_cls"> <span class="sig-name descname"><span class="pre">attn_weights_value_einsum_cls</span></span><a class="headerlink" href="#flax.linen.MultiHeadDotProductAttention.attn_weights_value_einsum_cls" title="Permalink to this definition">#</a></dt> <dd><p>factory function to create the einsum for computing the product of the attention weights and the values.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Callable[[…], collections.abc.Callable[[…], Union[jax.Array, Any]]] | None</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.MultiHeadDotProductAttention.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">inputs_q</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs_k</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs_v</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs_kv</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">deterministic</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dropout_rng</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">sow_weights</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/attention.html#MultiHeadDotProductAttention.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.MultiHeadDotProductAttention.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Applies multi-head dot product attention on the input data.</p> <p>Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention and project the results to an output vector.</p> <p>If both inputs_k and inputs_v are None, they will both copy the value of inputs_q (self attention). If only inputs_v is None, it will copy the value of inputs_k.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>inputs_q</strong> – input queries of shape <code class="docutils literal notranslate"><span class="pre">[batch_sizes...,</span> <span class="pre">length,</span> <span class="pre">features]</span></code>.</p></li> <li><p><strong>inputs_k</strong> – key of shape <code class="docutils literal notranslate"><span class="pre">[batch_sizes...,</span> <span class="pre">length,</span> <span class="pre">features]</span></code>. If None, inputs_k will copy the value of inputs_q.</p></li> <li><p><strong>inputs_v</strong> – values of shape <code class="docutils literal notranslate"><span class="pre">[batch_sizes...,</span> <span class="pre">length,</span> <span class="pre">features]</span></code>. If None, inputs_v will copy the value of inputs_k.</p></li> <li><p><strong>inputs_kv</strong> – key/values of shape <code class="docutils literal notranslate"><span class="pre">[batch_sizes...,</span> <span class="pre">length,</span> <span class="pre">features]</span></code>. If None, inputs_kv will copy the value of inputs_q. This arg will be deprecated soon. Use inputs_k and inputs_v instead.</p></li> <li><p><strong>mask</strong> – attention mask of shape <code class="docutils literal notranslate"><span class="pre">[batch_sizes...,</span> <span class="pre">num_heads,</span> <span class="pre">query_length,</span> <span class="pre">key/value_length]</span></code>. Attention weights are masked out if their corresponding mask value is <code class="docutils literal notranslate"><span class="pre">False</span></code>.</p></li> <li><p><strong>deterministic</strong> – if false, the attention weight is masked randomly using dropout, whereas if true, the attention weights are deterministic.</p></li> <li><p><strong>dropout_rng</strong> – optional rng key to pass to the attention layer’s dropout mask. Otherwise, self.make_rng(‘dropout’) is used instead.</p></li> <li><p><strong>sow_weights</strong> – if <code class="docutils literal notranslate"><span class="pre">True</span></code>, the attention weights are sowed into the ‘intermediates’ collection. Remember to mark ‘intermediates’ as mutable via <code class="docutils literal notranslate"><span class="pre">mutable=['intermediates']</span></code> in order to have that collection returned.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>output of shape <code class="docutils literal notranslate"><span class="pre">[batch_sizes...,</span> <span class="pre">length,</span> <span class="pre">features]</span></code>.</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> </tbody> </table> </div> </dd></dl> </div> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.MultiHeadAttention"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">MultiHeadAttention</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">num_heads</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">qkv_features=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_features=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">broadcast_dropout=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dropout_rate=0.0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">deterministic=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">precision=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kernel_init=<function</span> <span class="pre">variance_scaling.<locals>.init></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_kernel_init=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bias_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_bias_init=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_bias=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">attention_fn=<function</span> <span class="pre">dot_product_attention></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">decode=False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">normalize_qk=False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">force_fp32_for_softmax=False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">qkv_dot_general=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_dot_general=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">qkv_dot_general_cls=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_dot_general_cls=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">qk_attn_weights_einsum_cls=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">attn_weights_value_einsum_cls=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/attention.html#MultiHeadAttention"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.MultiHeadAttention" title="Permalink to this definition">#</a></dt> <dd><p>Multi-head dot-product attention. Alias for <code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention</span></code>.</p> <p><strong>NOTE</strong>: <code class="docutils literal notranslate"><span class="pre">MultiHeadAttention</span></code> is a wrapper of <code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention</span></code>, and so their implementations are identical. However <code class="docutils literal notranslate"><span class="pre">MultiHeadAttention</span></code> layers will, by default, be named <code class="docutils literal notranslate"><span class="pre">MultiHeadAttention_{index}</span></code>, whereas <code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention</span></code> will be named <code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention_{index}</span></code>. Therefore, this could affect checkpointing, param collection names and RNG threading (since the layer name is used when generating new RNG’s) within the module.</p> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MultiHeadAttention</span><span class="p">(</span><span class="n">num_heads</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">qkv_features</span><span class="o">=</span><span class="mi">16</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">key1</span><span class="p">,</span> <span class="n">key2</span><span class="p">,</span> <span class="n">key3</span><span class="p">,</span> <span class="n">key4</span><span class="p">,</span> <span class="n">key5</span><span class="p">,</span> <span class="n">key6</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="mi">6</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">shape</span> <span class="o">=</span> <span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">5</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">key1</span><span class="p">,</span> <span class="n">shape</span><span class="p">),</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">key2</span><span class="p">,</span> <span class="n">shape</span><span class="p">),</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">key3</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">q</span><span class="p">)</span> <span class="gp">>>> </span><span class="c1"># different inputs for inputs_q, inputs_k and inputs_v</span> <span class="gp">>>> </span><span class="n">out</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span><span class="p">)</span> <span class="gp">>>> </span><span class="c1"># equivalent to layer.apply(variables, inputs_q=q, inputs_k=k, inputs_v=k)</span> <span class="gp">>>> </span><span class="n">out</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">)</span> <span class="gp">>>> </span><span class="c1"># equivalent to layer.apply(variables, inputs_q=q, inputs_k=q) and layer.apply(variables, inputs_q=q, inputs_k=q, inputs_v=q)</span> <span class="gp">>>> </span><span class="n">out</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">q</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">attention_kwargs</span> <span class="o">=</span> <span class="nb">dict</span><span class="p">(</span> <span class="gp">... </span> <span class="n">num_heads</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="gp">... </span> <span class="n">qkv_features</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="gp">... </span> <span class="n">kernel_init</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">ones</span><span class="p">,</span> <span class="gp">... </span> <span class="n">bias_init</span><span class="o">=</span><span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">zeros</span><span class="p">,</span> <span class="gp">... </span> <span class="n">dropout_rate</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="gp">... </span> <span class="n">deterministic</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="gp">... </span> <span class="p">)</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">Module</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="n">attention_kwargs</span><span class="p">:</span> <span class="nb">dict</span> <span class="gp">...</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">dropout_rng</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="gp">... </span> <span class="n">out1</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MultiHeadAttention</span><span class="p">(</span><span class="o">**</span><span class="bp">self</span><span class="o">.</span><span class="n">attention_kwargs</span><span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span class="n">dropout_rng</span><span class="o">=</span><span class="n">dropout_rng</span><span class="p">)</span> <span class="gp">... </span> <span class="n">out2</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MultiHeadAttention</span><span class="p">(</span><span class="o">**</span><span class="bp">self</span><span class="o">.</span><span class="n">attention_kwargs</span><span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span class="n">dropout_rng</span><span class="o">=</span><span class="n">dropout_rng</span><span class="p">)</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">out1</span><span class="p">,</span> <span class="n">out2</span> <span class="gp">>>> </span><span class="n">module</span> <span class="o">=</span> <span class="n">Module</span><span class="p">(</span><span class="n">attention_kwargs</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">init</span><span class="p">({</span><span class="s1">'params'</span><span class="p">:</span> <span class="n">key1</span><span class="p">,</span> <span class="s1">'dropout'</span><span class="p">:</span> <span class="n">key2</span><span class="p">},</span> <span class="n">q</span><span class="p">)</span> <span class="gp">>>> </span><span class="c1"># out1 and out2 are different.</span> <span class="gp">>>> </span><span class="n">out1</span><span class="p">,</span> <span class="n">out2</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">rngs</span><span class="o">=</span><span class="p">{</span><span class="s1">'dropout'</span><span class="p">:</span> <span class="n">key3</span><span class="p">})</span> <span class="gp">>>> </span><span class="c1"># out3 and out4 are different.</span> <span class="gp">>>> </span><span class="c1"># out1 and out3 are different. out2 and out4 are different.</span> <span class="gp">>>> </span><span class="n">out3</span><span class="p">,</span> <span class="n">out4</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">rngs</span><span class="o">=</span><span class="p">{</span><span class="s1">'dropout'</span><span class="p">:</span> <span class="n">key4</span><span class="p">})</span> <span class="gp">>>> </span><span class="c1"># out1 and out2 are the same.</span> <span class="gp">>>> </span><span class="n">out1</span><span class="p">,</span> <span class="n">out2</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">dropout_rng</span><span class="o">=</span><span class="n">key5</span><span class="p">)</span> <span class="gp">>>> </span><span class="c1"># out1 and out2 are the same as out3 and out4.</span> <span class="gp">>>> </span><span class="c1"># providing a `dropout_rng` arg will take precedence over the `rngs` arg in `.apply`</span> <span class="gp">>>> </span><span class="n">out3</span><span class="p">,</span> <span class="n">out4</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">q</span><span class="p">,</span> <span class="n">rngs</span><span class="o">=</span><span class="p">{</span><span class="s1">'dropout'</span><span class="p">:</span> <span class="n">key6</span><span class="p">},</span> <span class="n">dropout_rng</span><span class="o">=</span><span class="n">key5</span><span class="p">)</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadAttention.num_heads"> <span class="sig-name descname"><span class="pre">num_heads</span></span><a class="headerlink" href="#flax.linen.MultiHeadAttention.num_heads" title="Permalink to this definition">#</a></dt> <dd><p>number of attention heads. Features (i.e. inputs_q.shape[-1]) should be divisible by the number of heads.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadAttention.dtype"> <span class="sig-name descname"><span class="pre">dtype</span></span><a class="headerlink" href="#flax.linen.MultiHeadAttention.dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype of the computation (default: infer from inputs and params)</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadAttention.param_dtype"> <span class="sig-name descname"><span class="pre">param_dtype</span></span><a class="headerlink" href="#flax.linen.MultiHeadAttention.param_dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype passed to parameter initializers (default: float32)</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadAttention.qkv_features"> <span class="sig-name descname"><span class="pre">qkv_features</span></span><a class="headerlink" href="#flax.linen.MultiHeadAttention.qkv_features" title="Permalink to this definition">#</a></dt> <dd><p>dimension of the key, query, and value.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int | None</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadAttention.out_features"> <span class="sig-name descname"><span class="pre">out_features</span></span><a class="headerlink" href="#flax.linen.MultiHeadAttention.out_features" title="Permalink to this definition">#</a></dt> <dd><p>dimension of the last projection</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int | None</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadAttention.broadcast_dropout"> <span class="sig-name descname"><span class="pre">broadcast_dropout</span></span><a class="headerlink" href="#flax.linen.MultiHeadAttention.broadcast_dropout" title="Permalink to this definition">#</a></dt> <dd><p>bool: use a broadcasted dropout along batch dims.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadAttention.dropout_rate"> <span class="sig-name descname"><span class="pre">dropout_rate</span></span><a class="headerlink" href="#flax.linen.MultiHeadAttention.dropout_rate" title="Permalink to this definition">#</a></dt> <dd><p>dropout rate</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>float</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadAttention.deterministic"> <span class="sig-name descname"><span class="pre">deterministic</span></span><a class="headerlink" href="#flax.linen.MultiHeadAttention.deterministic" title="Permalink to this definition">#</a></dt> <dd><p>if false, the attention weight is masked randomly using dropout, whereas if true, the attention weights are deterministic.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool | None</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadAttention.precision"> <span class="sig-name descname"><span class="pre">precision</span></span><a class="headerlink" href="#flax.linen.MultiHeadAttention.precision" title="Permalink to this definition">#</a></dt> <dd><p>numerical precision of the computation see <code class="docutils literal notranslate"><span class="pre">jax.lax.Precision</span></code> for details.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[None, str, jax._src.lax.lax.Precision, tuple[str, str], tuple[jax._src.lax.lax.Precision, jax._src.lax.lax.Precision]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadAttention.kernel_init"> <span class="sig-name descname"><span class="pre">kernel_init</span></span><a class="headerlink" href="#flax.linen.MultiHeadAttention.kernel_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer for the kernel of the Dense layers.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadAttention.bias_init"> <span class="sig-name descname"><span class="pre">bias_init</span></span><a class="headerlink" href="#flax.linen.MultiHeadAttention.bias_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer for the bias of the Dense layers.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadAttention.use_bias"> <span class="sig-name descname"><span class="pre">use_bias</span></span><a class="headerlink" href="#flax.linen.MultiHeadAttention.use_bias" title="Permalink to this definition">#</a></dt> <dd><p>bool: whether pointwise QKVO dense transforms use bias.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadAttention.attention_fn"> <span class="sig-name descname"><span class="pre">attention_fn</span></span><a class="headerlink" href="#flax.linen.MultiHeadAttention.attention_fn" title="Permalink to this definition">#</a></dt> <dd><p>dot_product_attention or compatible function. Accepts query, key, value, and returns output of shape <code class="docutils literal notranslate"><span class="pre">[bs,</span> <span class="pre">dim1,</span> <span class="pre">dim2,</span> <span class="pre">...,</span> <span class="pre">dimN,,</span> <span class="pre">num_heads,</span> <span class="pre">value_channels]</span></code></p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Callable[[…], Union[jax.Array, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadAttention.decode"> <span class="sig-name descname"><span class="pre">decode</span></span><a class="headerlink" href="#flax.linen.MultiHeadAttention.decode" title="Permalink to this definition">#</a></dt> <dd><p>whether to prepare and use an autoregressive cache.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MultiHeadAttention.normalize_qk"> <span class="sig-name descname"><span class="pre">normalize_qk</span></span><a class="headerlink" href="#flax.linen.MultiHeadAttention.normalize_qk" title="Permalink to this definition">#</a></dt> <dd><p>should QK normalization be applied (arxiv.org/abs/2302.05442).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.MultiHeadAttention.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">inputs_q</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs_k</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs_v</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs_kv</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">deterministic</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dropout_rng</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">sow_weights</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#flax.linen.MultiHeadAttention.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Applies multi-head dot product attention on the input data.</p> <p>Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention and project the results to an output vector.</p> <p>If both inputs_k and inputs_v are None, they will both copy the value of inputs_q (self attention). If only inputs_v is None, it will copy the value of inputs_k.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>inputs_q</strong> – input queries of shape <code class="docutils literal notranslate"><span class="pre">[batch_sizes...,</span> <span class="pre">length,</span> <span class="pre">features]</span></code>.</p></li> <li><p><strong>inputs_k</strong> – key of shape <code class="docutils literal notranslate"><span class="pre">[batch_sizes...,</span> <span class="pre">length,</span> <span class="pre">features]</span></code>. If None, inputs_k will copy the value of inputs_q.</p></li> <li><p><strong>inputs_v</strong> – values of shape <code class="docutils literal notranslate"><span class="pre">[batch_sizes...,</span> <span class="pre">length,</span> <span class="pre">features]</span></code>. If None, inputs_v will copy the value of inputs_k.</p></li> <li><p><strong>inputs_kv</strong> – key/values of shape <code class="docutils literal notranslate"><span class="pre">[batch_sizes...,</span> <span class="pre">length,</span> <span class="pre">features]</span></code>. If None, inputs_kv will copy the value of inputs_q. This arg will be deprecated soon. Use inputs_k and inputs_v instead.</p></li> <li><p><strong>mask</strong> – attention mask of shape <code class="docutils literal notranslate"><span class="pre">[batch_sizes...,</span> <span class="pre">num_heads,</span> <span class="pre">query_length,</span> <span class="pre">key/value_length]</span></code>. Attention weights are masked out if their corresponding mask value is <code class="docutils literal notranslate"><span class="pre">False</span></code>.</p></li> <li><p><strong>deterministic</strong> – if false, the attention weight is masked randomly using dropout, whereas if true, the attention weights are deterministic.</p></li> <li><p><strong>dropout_rng</strong> – optional rng key to pass to the attention layer’s dropout mask. Otherwise, self.make_rng(‘dropout’) is used instead.</p></li> <li><p><strong>sow_weights</strong> – if <code class="docutils literal notranslate"><span class="pre">True</span></code>, the attention weights are sowed into the ‘intermediates’ collection. Remember to mark ‘intermediates’ as mutable via <code class="docutils literal notranslate"><span class="pre">mutable=['intermediates']</span></code> in order to have that collection returned.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>output of shape <code class="docutils literal notranslate"><span class="pre">[batch_sizes...,</span> <span class="pre">length,</span> <span class="pre">features]</span></code>.</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> </tbody> </table> </div> </dd></dl> </div> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.SelfAttention"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">SelfAttention</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">num_heads</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">qkv_features=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_features=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">broadcast_dropout=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dropout_rate=0.0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">deterministic=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">precision=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kernel_init=<function</span> <span class="pre">variance_scaling.<locals>.init></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_kernel_init=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bias_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_bias_init=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_bias=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">attention_fn=<function</span> <span class="pre">dot_product_attention></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">decode=False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">normalize_qk=False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">force_fp32_for_softmax=False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">qkv_dot_general=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_dot_general=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">qkv_dot_general_cls=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_dot_general_cls=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">qk_attn_weights_einsum_cls=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">attn_weights_value_einsum_cls=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/attention.html#SelfAttention"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.SelfAttention" title="Permalink to this definition">#</a></dt> <dd><p>Self-attention special case of multi-head dot-product attention. This layer is deprecated in favor of <code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention</span></code>.</p> <dl> <dt>Example usage::</dt><dd><div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MultiHeadDotProductAttention</span><span class="p">(</span><span class="n">num_heads</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">qkv_features</span><span class="o">=</span><span class="mi">16</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">4</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">5</span><span class="p">)))</span> </pre></div> </div> </dd> </dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.SelfAttention.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">inputs_q</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">deterministic</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dropout_rng</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">sow_weights</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/attention.html#SelfAttention.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.SelfAttention.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Applies multi-head dot product self-attention on the input data.</p> <p>Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention and project the results to an output vector.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>inputs_q</strong> – input queries of shape <code class="docutils literal notranslate"><span class="pre">[batch_sizes...,</span> <span class="pre">length,</span> <span class="pre">features]</span></code>.</p></li> <li><p><strong>mask</strong> – attention mask of shape <code class="docutils literal notranslate"><span class="pre">[batch_sizes...,</span> <span class="pre">num_heads,</span> <span class="pre">query_length,</span> <span class="pre">key/value_length]</span></code>. Attention weights are masked out if their corresponding mask value is <code class="docutils literal notranslate"><span class="pre">False</span></code>.</p></li> <li><p><strong>deterministic</strong> – if false, the attention weight is masked randomly using dropout, whereas if true, the attention weights are deterministic.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>output of shape <code class="docutils literal notranslate"><span class="pre">[batch_sizes...,</span> <span class="pre">length,</span> <span class="pre">features]</span></code>.</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> </tbody> </table> </div> </dd></dl> </div> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.dot_product_attention_weights"> <span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">dot_product_attention_weights</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">query</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">key</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bias</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">broadcast_dropout</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dropout_rng</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dropout_rate</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">deterministic</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">precision</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">module</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">force_fp32_for_softmax</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">einsum_dot_general</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">einsum</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/attention.html#dot_product_attention_weights"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.dot_product_attention_weights" title="Permalink to this definition">#</a></dt> <dd><p>Computes dot-product attention weights given query and key.</p> <p>Used by <a class="reference internal" href="#flax.linen.dot_product_attention" title="flax.linen.dot_product_attention"><code class="xref py py-func docutils literal notranslate"><span class="pre">dot_product_attention()</span></code></a>, which is what you’ll most likely use. But if you want access to the attention weights for introspection, then you can directly call this function and call einsum yourself.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>query</strong> – queries for calculating attention with shape of <code class="docutils literal notranslate"><span class="pre">[batch...,</span> <span class="pre">q_length,</span> <span class="pre">num_heads,</span> <span class="pre">qk_depth_per_head]</span></code>.</p></li> <li><p><strong>key</strong> – keys for calculating attention with shape of <code class="docutils literal notranslate"><span class="pre">[batch...,</span> <span class="pre">kv_length,</span> <span class="pre">num_heads,</span> <span class="pre">qk_depth_per_head]</span></code>.</p></li> <li><p><strong>bias</strong> – bias for the attention weights. This should be broadcastable to the shape <code class="docutils literal notranslate"><span class="pre">[batch...,</span> <span class="pre">num_heads,</span> <span class="pre">q_length,</span> <span class="pre">kv_length]</span></code>. This can be used for incorporating causal masks, padding masks, proximity bias, etc.</p></li> <li><p><strong>mask</strong> – mask for the attention weights. This should be broadcastable to the shape <code class="docutils literal notranslate"><span class="pre">[batch...,</span> <span class="pre">num_heads,</span> <span class="pre">q_length,</span> <span class="pre">kv_length]</span></code>. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is <code class="docutils literal notranslate"><span class="pre">False</span></code>.</p></li> <li><p><strong>broadcast_dropout</strong> – bool: use a broadcasted dropout along batch dims.</p></li> <li><p><strong>dropout_rng</strong> – JAX PRNGKey: to be used for dropout</p></li> <li><p><strong>dropout_rate</strong> – dropout rate</p></li> <li><p><strong>deterministic</strong> – bool, deterministic or not (to apply dropout)</p></li> <li><p><strong>dtype</strong> – the dtype of the computation (default: infer from inputs and params)</p></li> <li><p><strong>precision</strong> – numerical precision of the computation see <code class="docutils literal notranslate"><span class="pre">jax.lax.Precision</span></code> for details.</p></li> <li><p><strong>module</strong> – the Module that will sow the attention weights into the ‘intermediates’ collection. Remember to mark ‘intermediates’ as mutable via <code class="docutils literal notranslate"><span class="pre">mutable=['intermediates']</span></code> in order to have that collection returned. If <code class="docutils literal notranslate"><span class="pre">module</span></code> is None, the attention weights will not be sowed.</p></li> <li><p><strong>force_fp32_for_softmax</strong> – bool, whether to force the softmax to be computed in fp32. This is useful for mixed-precision training where higher precision is desired for numerical stability.</p></li> <li><p><strong>einsum_dot_general</strong> – the dot_general to use in einsum.</p></li> <li><p><strong>einsum</strong> – If unspecified, default <cite>jnp.einsum</cite> will be used. This argument is mutually exclusive with <cite>precision</cite> and <cite>einsum_dot_general</cite>.</p></li> </ul> </dd> <dt class="field-even">Raises</dt> <dd class="field-even"><p><strong>ValueError</strong> – if both <cite>precision</cite>/<cite>einsum_dot_general</cite> and <cite>einsum</cite> are specified.</p> </dd> <dt class="field-odd">Returns</dt> <dd class="field-odd"><p>Output of shape <code class="docutils literal notranslate"><span class="pre">[batch...,</span> <span class="pre">num_heads,</span> <span class="pre">q_length,</span> <span class="pre">kv_length]</span></code>.</p> </dd> </dl> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.dot_product_attention"> <span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">dot_product_attention</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">query</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">key</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bias</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mask</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">broadcast_dropout</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dropout_rng</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dropout_rate</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">0.0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">deterministic</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">precision</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">module</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">force_fp32_for_softmax</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">einsum_dot_general</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">qk_attn_weights_einsum</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">attn_weights_value_einsum</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/attention.html#dot_product_attention"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.dot_product_attention" title="Permalink to this definition">#</a></dt> <dd><p>Computes dot-product attention given query, key, and value.</p> <p>This is the core function for applying attention based on <a class="reference external" href="https://arxiv.org/abs/1706.03762">https://arxiv.org/abs/1706.03762</a>. It calculates the attention weights given query and key and combines the values using the attention weights.</p> <div class="admonition note"> <p class="admonition-title">Note</p> <p><code class="docutils literal notranslate"><span class="pre">query</span></code>, <code class="docutils literal notranslate"><span class="pre">key</span></code>, <code class="docutils literal notranslate"><span class="pre">value</span></code> needn’t have any batch dimensions.</p> </div> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>query</strong> – queries for calculating attention with shape of <code class="docutils literal notranslate"><span class="pre">[batch...,</span> <span class="pre">q_length,</span> <span class="pre">num_heads,</span> <span class="pre">qk_depth_per_head]</span></code>.</p></li> <li><p><strong>key</strong> – keys for calculating attention with shape of <code class="docutils literal notranslate"><span class="pre">[batch...,</span> <span class="pre">kv_length,</span> <span class="pre">num_heads,</span> <span class="pre">qk_depth_per_head]</span></code>.</p></li> <li><p><strong>value</strong> – values to be used in attention with shape of <code class="docutils literal notranslate"><span class="pre">[batch...,</span> <span class="pre">kv_length,</span> <span class="pre">num_heads,</span> <span class="pre">v_depth_per_head]</span></code>.</p></li> <li><p><strong>bias</strong> – bias for the attention weights. This should be broadcastable to the shape <code class="docutils literal notranslate"><span class="pre">[batch...,</span> <span class="pre">num_heads,</span> <span class="pre">q_length,</span> <span class="pre">kv_length]</span></code>. This can be used for incorporating causal masks, padding masks, proximity bias, etc.</p></li> <li><p><strong>mask</strong> – mask for the attention weights. This should be broadcastable to the shape <code class="docutils literal notranslate"><span class="pre">[batch...,</span> <span class="pre">num_heads,</span> <span class="pre">q_length,</span> <span class="pre">kv_length]</span></code>. This can be used for incorporating causal masks. Attention weights are masked out if their corresponding mask value is <code class="docutils literal notranslate"><span class="pre">False</span></code>.</p></li> <li><p><strong>broadcast_dropout</strong> – bool: use a broadcasted dropout along batch dims.</p></li> <li><p><strong>dropout_rng</strong> – JAX PRNGKey: to be used for dropout</p></li> <li><p><strong>dropout_rate</strong> – dropout rate</p></li> <li><p><strong>deterministic</strong> – bool, deterministic or not (to apply dropout)</p></li> <li><p><strong>dtype</strong> – the dtype of the computation (default: infer from inputs)</p></li> <li><p><strong>precision</strong> – numerical precision of the computation see <a href="#id3"><span class="problematic" id="id4">``</span></a>jax.lax.Precision` for details.</p></li> <li><p><strong>module</strong> – the Module that will sow the attention weights into the ‘intermediates’ collection. Remember to mark ‘intermediates’ as mutable via <code class="docutils literal notranslate"><span class="pre">mutable=['intermediates']</span></code> in order to have that collection returned. If <code class="docutils literal notranslate"><span class="pre">module</span></code> is None, the attention weights will not be sowed.</p></li> <li><p><strong>force_fp32_for_softmax</strong> – bool, whether to force the softmax to be computed in fp32. This is useful for mixed-precision training where higher precision is desired for numerical stability.</p></li> <li><p><strong>einsum_dot_general</strong> – the dot_general to use in <cite>jnp.einsum</cite>.</p></li> <li><p><strong>qk_attn_weights_einsum</strong> – the einsum for computing the attention weights. When unspecified, the default <cite>jnp.einsum</cite> will be used. This argument is mutually exclusive with <cite>precision</cite> and <cite>einsum_dot_general</cite>.</p></li> <li><p><strong>attn_weights_value_einsum</strong> – the einsum for computing the product of the attention weights and the values. When unspecified, the default <cite>jnp.einsum</cite> will be used. This argument is mutually exclusive with <cite>precision</cite> and <cite>einsum_dot_general</cite>.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>Output of shape <code class="docutils literal notranslate"><span class="pre">[batch...,</span> <span class="pre">q_length,</span> <span class="pre">num_heads,</span> <span class="pre">v_depth_per_head]</span></code>.</p> </dd> <dt class="field-odd">Raises</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>ValueError</strong> – if both <cite>precision</cite>/<cite>einsum_dot_general</cite> and</p></li> <li><p><strong>qk_attn_weights_einsum</strong> – specified.</p></li> </ul> </dd> </dl> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.make_attention_mask"> <span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">make_attention_mask</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">query_input</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">key_input</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">pairwise_fn=<jnp.ufunc</span> <span class="pre">'multiply'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">extra_batch_dims=0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/attention.html#make_attention_mask"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.make_attention_mask" title="Permalink to this definition">#</a></dt> <dd><p>Mask-making helper for attention weights.</p> <p>In case of 1d inputs (i.e., <code class="docutils literal notranslate"><span class="pre">[batch...,</span> <span class="pre">len_q]</span></code>, <code class="docutils literal notranslate"><span class="pre">[batch...,</span> <span class="pre">len_kv]</span></code>, the attention weights will be <code class="docutils literal notranslate"><span class="pre">[batch...,</span> <span class="pre">heads,</span> <span class="pre">len_q,</span> <span class="pre">len_kv]</span></code> and this function will produce <code class="docutils literal notranslate"><span class="pre">[batch...,</span> <span class="pre">1,</span> <span class="pre">len_q,</span> <span class="pre">len_kv]</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>query_input</strong> – a batched, flat input of query_length size</p></li> <li><p><strong>key_input</strong> – a batched, flat input of key_length size</p></li> <li><p><strong>pairwise_fn</strong> – broadcasting elementwise comparison function</p></li> <li><p><strong>extra_batch_dims</strong> – number of extra batch dims to add singleton axes for, none by default</p></li> <li><p><strong>dtype</strong> – mask return dtype</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>A <code class="docutils literal notranslate"><span class="pre">[batch...,</span> <span class="pre">1,</span> <span class="pre">len_q,</span> <span class="pre">len_kv]</span></code> shaped mask for 1d attention.</p> </dd> </dl> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.make_causal_mask"> <span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">make_causal_mask</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">x</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">extra_batch_dims=0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/attention.html#make_causal_mask"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.make_causal_mask" title="Permalink to this definition">#</a></dt> <dd><p>Make a causal mask for self-attention.</p> <p>In case of 1d inputs (i.e., <code class="docutils literal notranslate"><span class="pre">[batch...,</span> <span class="pre">len]</span></code>, the self-attention weights will be <code class="docutils literal notranslate"><span class="pre">[batch...,</span> <span class="pre">heads,</span> <span class="pre">len,</span> <span class="pre">len]</span></code> and this function will produce a causal mask of shape <code class="docutils literal notranslate"><span class="pre">[batch...,</span> <span class="pre">1,</span> <span class="pre">len,</span> <span class="pre">len]</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>x</strong> – input array of shape <code class="docutils literal notranslate"><span class="pre">[batch...,</span> <span class="pre">len]</span></code></p></li> <li><p><strong>extra_batch_dims</strong> – number of batch dims to add singleton axes for, none by default</p></li> <li><p><strong>dtype</strong> – mask return dtype</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>A <code class="docutils literal notranslate"><span class="pre">[batch...,</span> <span class="pre">1,</span> <span class="pre">len,</span> <span class="pre">len]</span></code> shaped causal mask for 1d attention.</p> </dd> </dl> </dd></dl> </div> <div class="section" id="recurrent"> <h2>Recurrent<a class="headerlink" href="#recurrent" title="Permalink to this heading">#</a></h2> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.RNNCellBase"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">RNNCellBase</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#RNNCellBase"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.RNNCellBase" title="Permalink to this definition">#</a></dt> <dd><p>RNN cell base class.</p> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.RNNCellBase.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#flax.linen.RNNCellBase.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Call self as a function.</p> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.RNNCellBase.initialize_carry"> <span class="sig-name descname"><span class="pre">initialize_carry</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">rng</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">input_shape</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#RNNCellBase.initialize_carry"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.RNNCellBase.initialize_carry" title="Permalink to this definition">#</a></dt> <dd><p>Initialize the RNN cell carry.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>rng</strong> – random number generator passed to the init_fn.</p></li> <li><p><strong>input_shape</strong> – a tuple providing the shape of the input to the cell.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>An initialized carry for the given RNN cell.</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> <tr class="row-odd"><td><p><a class="reference internal" href="#flax.linen.RNNCellBase.initialize_carry" title="flax.linen.RNNCellBase.initialize_carry"><code class="xref py py-obj docutils literal notranslate"><span class="pre">initialize_carry</span></code></a>(rng, input_shape)</p></td> <td><p>Initialize the RNN cell carry.</p></td> </tr> </tbody> </table> </div> </dd></dl> </div> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.LSTMCell"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">LSTMCell</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gate_fn=<PjitFunction</span> <span class="pre">of</span> <span class="pre"><function</span> <span class="pre">sigmoid>></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">activation_fn=<PjitFunction</span> <span class="pre">of</span> <span class="pre"><function</span> <span class="pre">tanh>></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kernel_init=<function</span> <span class="pre">variance_scaling.<locals>.init></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">recurrent_kernel_init=<function</span> <span class="pre">orthogonal.<locals>.init></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bias_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">carry_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#LSTMCell"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.LSTMCell" title="Permalink to this definition">#</a></dt> <dd><p>LSTM cell.</p> <p>The mathematical definition of the cell is as follows</p> <div class="math notranslate nohighlight"> \[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]</div> <p>where x is the input, h is the output of the previous time step, and c is the memory.</p> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">LSTMCell</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">carry</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">initialize_carry</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">carry</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">new_carry</span><span class="p">,</span> <span class="n">out</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">carry</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.LSTMCell.features"> <span class="sig-name descname"><span class="pre">features</span></span><a class="headerlink" href="#flax.linen.LSTMCell.features" title="Permalink to this definition">#</a></dt> <dd><p>number of output features.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.LSTMCell.gate_fn"> <span class="sig-name descname"><span class="pre">gate_fn</span></span><a class="headerlink" href="#flax.linen.LSTMCell.gate_fn" title="Permalink to this definition">#</a></dt> <dd><p>activation function used for gates (default: sigmoid).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Callable[[…], Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.LSTMCell.activation_fn"> <span class="sig-name descname"><span class="pre">activation_fn</span></span><a class="headerlink" href="#flax.linen.LSTMCell.activation_fn" title="Permalink to this definition">#</a></dt> <dd><p>activation function used for output and memory update (default: tanh).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Callable[[…], Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.LSTMCell.kernel_init"> <span class="sig-name descname"><span class="pre">kernel_init</span></span><a class="headerlink" href="#flax.linen.LSTMCell.kernel_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer function for the kernels that transform the input (default: lecun_normal).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.LSTMCell.recurrent_kernel_init"> <span class="sig-name descname"><span class="pre">recurrent_kernel_init</span></span><a class="headerlink" href="#flax.linen.LSTMCell.recurrent_kernel_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.LSTMCell.bias_init"> <span class="sig-name descname"><span class="pre">bias_init</span></span><a class="headerlink" href="#flax.linen.LSTMCell.bias_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer for the bias parameters (default: initializers.zeros_init())</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.LSTMCell.dtype"> <span class="sig-name descname"><span class="pre">dtype</span></span><a class="headerlink" href="#flax.linen.LSTMCell.dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype of the computation (default: infer from inputs and params).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.LSTMCell.param_dtype"> <span class="sig-name descname"><span class="pre">param_dtype</span></span><a class="headerlink" href="#flax.linen.LSTMCell.param_dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype passed to parameter initializers (default: float32).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.LSTMCell.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">carry</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#LSTMCell.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.LSTMCell.__call__" title="Permalink to this definition">#</a></dt> <dd><p>A long short-term memory (LSTM) cell.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>carry</strong> – the hidden state of the LSTM cell, initialized using <code class="docutils literal notranslate"><span class="pre">LSTMCell.initialize_carry</span></code>.</p></li> <li><p><strong>inputs</strong> – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>A tuple with the new carry and the output.</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.LSTMCell.initialize_carry"> <span class="sig-name descname"><span class="pre">initialize_carry</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">rng</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">input_shape</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#LSTMCell.initialize_carry"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.LSTMCell.initialize_carry" title="Permalink to this definition">#</a></dt> <dd><p>Initialize the RNN cell carry.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>rng</strong> – random number generator passed to the init_fn.</p></li> <li><p><strong>input_shape</strong> – a tuple providing the shape of the input to the cell.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>An initialized carry for the given RNN cell.</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> <tr class="row-odd"><td><p><a class="reference internal" href="#flax.linen.LSTMCell.initialize_carry" title="flax.linen.LSTMCell.initialize_carry"><code class="xref py py-obj docutils literal notranslate"><span class="pre">initialize_carry</span></code></a>(rng, input_shape)</p></td> <td><p>Initialize the RNN cell carry.</p></td> </tr> </tbody> </table> </div> </dd></dl> </div> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.OptimizedLSTMCell"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">OptimizedLSTMCell</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gate_fn=<PjitFunction</span> <span class="pre">of</span> <span class="pre"><function</span> <span class="pre">sigmoid>></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">activation_fn=<PjitFunction</span> <span class="pre">of</span> <span class="pre"><function</span> <span class="pre">tanh>></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kernel_init=<function</span> <span class="pre">variance_scaling.<locals>.init></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">recurrent_kernel_init=<function</span> <span class="pre">orthogonal.<locals>.init></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bias_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">carry_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#OptimizedLSTMCell"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.OptimizedLSTMCell" title="Permalink to this definition">#</a></dt> <dd><p>More efficient LSTM Cell that concatenates state components before matmul.</p> <p>The parameters are compatible with <code class="docutils literal notranslate"><span class="pre">LSTMCell</span></code>. Note that this cell is often faster than <code class="docutils literal notranslate"><span class="pre">LSTMCell</span></code> as long as the hidden size is roughly <= 2048 units.</p> <p>The mathematical definition of the cell is the same as <code class="docutils literal notranslate"><span class="pre">LSTMCell</span></code> and as follows</p> <div class="math notranslate nohighlight"> \[\begin{split}\begin{array}{ll} i = \sigma(W_{ii} x + W_{hi} h + b_{hi}) \\ f = \sigma(W_{if} x + W_{hf} h + b_{hf}) \\ g = \tanh(W_{ig} x + W_{hg} h + b_{hg}) \\ o = \sigma(W_{io} x + W_{ho} h + b_{ho}) \\ c' = f * c + i * g \\ h' = o * \tanh(c') \\ \end{array}\end{split}\]</div> <p>where x is the input, h is the output of the previous time step, and c is the memory.</p> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">OptimizedLSTMCell</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">carry</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">initialize_carry</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">carry</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">new_carry</span><span class="p">,</span> <span class="n">out</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">carry</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.OptimizedLSTMCell.gate_fn"> <span class="sig-name descname"><span class="pre">gate_fn</span></span><a class="headerlink" href="#flax.linen.OptimizedLSTMCell.gate_fn" title="Permalink to this definition">#</a></dt> <dd><p>activation function used for gates (default: sigmoid).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Callable[[…], Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.OptimizedLSTMCell.activation_fn"> <span class="sig-name descname"><span class="pre">activation_fn</span></span><a class="headerlink" href="#flax.linen.OptimizedLSTMCell.activation_fn" title="Permalink to this definition">#</a></dt> <dd><p>activation function used for output and memory update (default: tanh).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Callable[[…], Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.OptimizedLSTMCell.kernel_init"> <span class="sig-name descname"><span class="pre">kernel_init</span></span><a class="headerlink" href="#flax.linen.OptimizedLSTMCell.kernel_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer function for the kernels that transform the input (default: lecun_normal).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.OptimizedLSTMCell.recurrent_kernel_init"> <span class="sig-name descname"><span class="pre">recurrent_kernel_init</span></span><a class="headerlink" href="#flax.linen.OptimizedLSTMCell.recurrent_kernel_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.OptimizedLSTMCell.bias_init"> <span class="sig-name descname"><span class="pre">bias_init</span></span><a class="headerlink" href="#flax.linen.OptimizedLSTMCell.bias_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer for the bias parameters (default: initializers.zeros_init()).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.OptimizedLSTMCell.dtype"> <span class="sig-name descname"><span class="pre">dtype</span></span><a class="headerlink" href="#flax.linen.OptimizedLSTMCell.dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype of the computation (default: infer from inputs and params).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.OptimizedLSTMCell.param_dtype"> <span class="sig-name descname"><span class="pre">param_dtype</span></span><a class="headerlink" href="#flax.linen.OptimizedLSTMCell.param_dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype passed to parameter initializers (default: float32).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.OptimizedLSTMCell.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">carry</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#OptimizedLSTMCell.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.OptimizedLSTMCell.__call__" title="Permalink to this definition">#</a></dt> <dd><p>An optimized long short-term memory (LSTM) cell.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>carry</strong> – the hidden state of the LSTM cell, initialized using <code class="docutils literal notranslate"><span class="pre">LSTMCell.initialize_carry</span></code>.</p></li> <li><p><strong>inputs</strong> – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>A tuple with the new carry and the output.</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.OptimizedLSTMCell.initialize_carry"> <span class="sig-name descname"><span class="pre">initialize_carry</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">rng</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">input_shape</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#OptimizedLSTMCell.initialize_carry"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.OptimizedLSTMCell.initialize_carry" title="Permalink to this definition">#</a></dt> <dd><p>Initialize the RNN cell carry.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>rng</strong> – random number generator passed to the init_fn.</p></li> <li><p><strong>input_shape</strong> – a tuple providing the shape of the input to the cell.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>An initialized carry for the given RNN cell.</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> <tr class="row-odd"><td><p><a class="reference internal" href="#flax.linen.OptimizedLSTMCell.initialize_carry" title="flax.linen.OptimizedLSTMCell.initialize_carry"><code class="xref py py-obj docutils literal notranslate"><span class="pre">initialize_carry</span></code></a>(rng, input_shape)</p></td> <td><p>Initialize the RNN cell carry.</p></td> </tr> </tbody> </table> </div> </dd></dl> </div> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.ConvLSTMCell"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">ConvLSTMCell</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kernel_size</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">strides=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">padding='SAME'</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">use_bias=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">carry_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#ConvLSTMCell"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.ConvLSTMCell" title="Permalink to this definition">#</a></dt> <dd><p>A convolutional LSTM cell.</p> <p>The implementation is based on xingjian2015convolutional. Given x_t and the previous state (h_{t-1}, c_{t-1}) the core computes</p> <div class="math notranslate nohighlight"> \[\begin{split}\begin{array}{ll} i_t = \sigma(W_{ii} * x_t + W_{hi} * h_{t-1} + b_i) \\ f_t = \sigma(W_{if} * x_t + W_{hf} * h_{t-1} + b_f) \\ g_t = \tanh(W_{ig} * x_t + W_{hg} * h_{t-1} + b_g) \\ o_t = \sigma(W_{io} * x_t + W_{ho} * h_{t-1} + b_o) \\ c_t = f_t c_{t-1} + i_t g_t \\ h_t = o_t \tanh(c_t) \end{array}\end{split}\]</div> <p>where * denotes the convolution operator; i_t, f_t, o_t are input, forget and output gate activations, and g_t is a vector of cell updates.</p> <div class="admonition note"> <p class="admonition-title">Note</p> <dl class="simple"> <dt>Forget gate initialization:</dt><dd><p>Following jozefowicz2015empirical we add 1.0 to b_f after initialization in order to reduce the scale of forgetting in the beginning of the training.</p> </dd> </dl> </div> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ConvLSTMCell</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">carry</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">initialize_carry</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">carry</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">new_carry</span><span class="p">,</span> <span class="n">out</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">carry</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvLSTMCell.features"> <span class="sig-name descname"><span class="pre">features</span></span><a class="headerlink" href="#flax.linen.ConvLSTMCell.features" title="Permalink to this definition">#</a></dt> <dd><p>number of convolution filters.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvLSTMCell.kernel_size"> <span class="sig-name descname"><span class="pre">kernel_size</span></span><a class="headerlink" href="#flax.linen.ConvLSTMCell.kernel_size" title="Permalink to this definition">#</a></dt> <dd><p>shape of the convolutional kernel.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Sequence[int]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvLSTMCell.strides"> <span class="sig-name descname"><span class="pre">strides</span></span><a class="headerlink" href="#flax.linen.ConvLSTMCell.strides" title="Permalink to this definition">#</a></dt> <dd><p>a sequence of <code class="docutils literal notranslate"><span class="pre">n</span></code> integers, representing the inter-window strides.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Sequence[int] | None</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvLSTMCell.padding"> <span class="sig-name descname"><span class="pre">padding</span></span><a class="headerlink" href="#flax.linen.ConvLSTMCell.padding" title="Permalink to this definition">#</a></dt> <dd><p>either the string <code class="docutils literal notranslate"><span class="pre">'SAME'</span></code>, the string <code class="docutils literal notranslate"><span class="pre">'VALID'</span></code>, or a sequence of <code class="docutils literal notranslate"><span class="pre">n</span></code> <code class="docutils literal notranslate"><span class="pre">(low,</span> <span class="pre">high)</span></code> integer pairs that give the padding to apply before and after each spatial dimension.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>str | collections.abc.Sequence[tuple[int, int]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvLSTMCell.bias"> <span class="sig-name descname"><span class="pre">bias</span></span><a class="headerlink" href="#flax.linen.ConvLSTMCell.bias" title="Permalink to this definition">#</a></dt> <dd><p>whether to add a bias to the output (default: True).</p> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvLSTMCell.dtype"> <span class="sig-name descname"><span class="pre">dtype</span></span><a class="headerlink" href="#flax.linen.ConvLSTMCell.dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype of the computation (default: None).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.ConvLSTMCell.param_dtype"> <span class="sig-name descname"><span class="pre">param_dtype</span></span><a class="headerlink" href="#flax.linen.ConvLSTMCell.param_dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype passed to parameter initializers (default: float32).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.ConvLSTMCell.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">carry</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#ConvLSTMCell.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.ConvLSTMCell.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Constructs a convolutional LSTM.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>carry</strong> – the hidden state of the Conv2DLSTM cell, initialized using <code class="docutils literal notranslate"><span class="pre">Conv2DLSTM.initialize_carry</span></code>.</p></li> <li><p><strong>inputs</strong> – input data with dimensions (batch, spatial_dims…, features).</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>A tuple with the new carry and the output.</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.ConvLSTMCell.initialize_carry"> <span class="sig-name descname"><span class="pre">initialize_carry</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">rng</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">input_shape</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#ConvLSTMCell.initialize_carry"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.ConvLSTMCell.initialize_carry" title="Permalink to this definition">#</a></dt> <dd><p>Initialize the RNN cell carry.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>rng</strong> – random number generator passed to the init_fn.</p></li> <li><p><strong>input_shape</strong> – a tuple providing the shape of the input to the cell.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>An initialized carry for the given RNN cell.</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> <tr class="row-odd"><td><p><a class="reference internal" href="#flax.linen.ConvLSTMCell.initialize_carry" title="flax.linen.ConvLSTMCell.initialize_carry"><code class="xref py py-obj docutils literal notranslate"><span class="pre">initialize_carry</span></code></a>(rng, input_shape)</p></td> <td><p>Initialize the RNN cell carry.</p></td> </tr> </tbody> </table> </div> </dd></dl> </div> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.SimpleCell"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">SimpleCell</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">activation_fn=<PjitFunction</span> <span class="pre">of</span> <span class="pre"><function</span> <span class="pre">tanh>></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kernel_init=<function</span> <span class="pre">variance_scaling.<locals>.init></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">recurrent_kernel_init=<function</span> <span class="pre">orthogonal.<locals>.init></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bias_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">carry_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">residual=False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#SimpleCell"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.SimpleCell" title="Permalink to this definition">#</a></dt> <dd><p>Simple cell.</p> <p>The mathematical definition of the cell is as follows</p> <div class="math notranslate nohighlight"> \[\begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h) \end{array}\]</div> <p>where x is the input and h is the output of the previous time step.</p> <p>If <cite>residual</cite> is <cite>True</cite>,</p> <div class="math notranslate nohighlight"> \[\begin{array}{ll} h' = \tanh(W_i x + b_i + W_h h + h) \end{array}\]</div> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">SimpleCell</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">carry</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">initialize_carry</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">carry</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">new_carry</span><span class="p">,</span> <span class="n">out</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">carry</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.SimpleCell.features"> <span class="sig-name descname"><span class="pre">features</span></span><a class="headerlink" href="#flax.linen.SimpleCell.features" title="Permalink to this definition">#</a></dt> <dd><p>number of output features.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.SimpleCell.activation_fn"> <span class="sig-name descname"><span class="pre">activation_fn</span></span><a class="headerlink" href="#flax.linen.SimpleCell.activation_fn" title="Permalink to this definition">#</a></dt> <dd><p>activation function used for output and memory update (default: tanh).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Callable[[…], Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.SimpleCell.kernel_init"> <span class="sig-name descname"><span class="pre">kernel_init</span></span><a class="headerlink" href="#flax.linen.SimpleCell.kernel_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer function for the kernels that transform the input (default: lecun_normal).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.SimpleCell.recurrent_kernel_init"> <span class="sig-name descname"><span class="pre">recurrent_kernel_init</span></span><a class="headerlink" href="#flax.linen.SimpleCell.recurrent_kernel_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.SimpleCell.bias_init"> <span class="sig-name descname"><span class="pre">bias_init</span></span><a class="headerlink" href="#flax.linen.SimpleCell.bias_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer for the bias parameters (default: initializers.zeros_init())</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.SimpleCell.dtype"> <span class="sig-name descname"><span class="pre">dtype</span></span><a class="headerlink" href="#flax.linen.SimpleCell.dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype of the computation (default: None).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.SimpleCell.param_dtype"> <span class="sig-name descname"><span class="pre">param_dtype</span></span><a class="headerlink" href="#flax.linen.SimpleCell.param_dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype passed to parameter initializers (default: float32).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.SimpleCell.residual"> <span class="sig-name descname"><span class="pre">residual</span></span><a class="headerlink" href="#flax.linen.SimpleCell.residual" title="Permalink to this definition">#</a></dt> <dd><p>pre-activation residual connection (<a class="reference external" href="https://arxiv.org/abs/1801.06105">https://arxiv.org/abs/1801.06105</a>).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.SimpleCell.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">carry</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#SimpleCell.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.SimpleCell.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Simple cell.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>carry</strong> – the hidden state of the Simple cell, initialized using <code class="docutils literal notranslate"><span class="pre">SimpleCell.initialize_carry</span></code>.</p></li> <li><p><strong>inputs</strong> – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>A tuple with the new carry and the output.</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.SimpleCell.initialize_carry"> <span class="sig-name descname"><span class="pre">initialize_carry</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">rng</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">input_shape</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#SimpleCell.initialize_carry"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.SimpleCell.initialize_carry" title="Permalink to this definition">#</a></dt> <dd><p>Initialize the RNN cell carry.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>rng</strong> – random number generator passed to the init_fn.</p></li> <li><p><strong>input_shape</strong> – a tuple providing the shape of the input to the cell.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>An initialized carry for the given RNN cell.</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> <tr class="row-odd"><td><p><a class="reference internal" href="#flax.linen.SimpleCell.initialize_carry" title="flax.linen.SimpleCell.initialize_carry"><code class="xref py py-obj docutils literal notranslate"><span class="pre">initialize_carry</span></code></a>(rng, input_shape)</p></td> <td><p>Initialize the RNN cell carry.</p></td> </tr> </tbody> </table> </div> </dd></dl> </div> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.GRUCell"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">GRUCell</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gate_fn=<PjitFunction</span> <span class="pre">of</span> <span class="pre"><function</span> <span class="pre">sigmoid>></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">activation_fn=<PjitFunction</span> <span class="pre">of</span> <span class="pre"><function</span> <span class="pre">tanh>></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kernel_init=<function</span> <span class="pre">variance_scaling.<locals>.init></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">recurrent_kernel_init=<function</span> <span class="pre">orthogonal.<locals>.init></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">bias_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">carry_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#GRUCell"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.GRUCell" title="Permalink to this definition">#</a></dt> <dd><p>GRU cell.</p> <p>The mathematical definition of the cell is as follows</p> <div class="math notranslate nohighlight"> \[\begin{split}\begin{array}{ll} r = \sigma(W_{ir} x + b_{ir} + W_{hr} h) \\ z = \sigma(W_{iz} x + b_{iz} + W_{hz} h) \\ n = \tanh(W_{in} x + b_{in} + r * (W_{hn} h + b_{hn})) \\ h' = (1 - z) * n + z * h \\ \end{array}\end{split}\]</div> <p>where x is the input and h is the output of the previous time step.</p> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">GRUCell</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">carry</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">initialize_carry</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">carry</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">new_carry</span><span class="p">,</span> <span class="n">out</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">carry</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.GRUCell.features"> <span class="sig-name descname"><span class="pre">features</span></span><a class="headerlink" href="#flax.linen.GRUCell.features" title="Permalink to this definition">#</a></dt> <dd><p>number of output features.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.GRUCell.gate_fn"> <span class="sig-name descname"><span class="pre">gate_fn</span></span><a class="headerlink" href="#flax.linen.GRUCell.gate_fn" title="Permalink to this definition">#</a></dt> <dd><p>activation function used for gates (default: sigmoid).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Callable[[…], Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.GRUCell.activation_fn"> <span class="sig-name descname"><span class="pre">activation_fn</span></span><a class="headerlink" href="#flax.linen.GRUCell.activation_fn" title="Permalink to this definition">#</a></dt> <dd><p>activation function used for output and memory update (default: tanh).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Callable[[…], Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.GRUCell.kernel_init"> <span class="sig-name descname"><span class="pre">kernel_init</span></span><a class="headerlink" href="#flax.linen.GRUCell.kernel_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer function for the kernels that transform the input (default: lecun_normal).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.GRUCell.recurrent_kernel_init"> <span class="sig-name descname"><span class="pre">recurrent_kernel_init</span></span><a class="headerlink" href="#flax.linen.GRUCell.recurrent_kernel_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.GRUCell.bias_init"> <span class="sig-name descname"><span class="pre">bias_init</span></span><a class="headerlink" href="#flax.linen.GRUCell.bias_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer for the bias parameters (default: initializers.zeros_init())</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.GRUCell.dtype"> <span class="sig-name descname"><span class="pre">dtype</span></span><a class="headerlink" href="#flax.linen.GRUCell.dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype of the computation (default: None).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.GRUCell.param_dtype"> <span class="sig-name descname"><span class="pre">param_dtype</span></span><a class="headerlink" href="#flax.linen.GRUCell.param_dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype passed to parameter initializers (default: float32).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.GRUCell.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">carry</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#GRUCell.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.GRUCell.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Gated recurrent unit (GRU) cell.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>carry</strong> – the hidden state of the GRU cell, initialized using <code class="docutils literal notranslate"><span class="pre">GRUCell.initialize_carry</span></code>.</p></li> <li><p><strong>inputs</strong> – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>A tuple with the new carry and the output.</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.GRUCell.initialize_carry"> <span class="sig-name descname"><span class="pre">initialize_carry</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">rng</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">input_shape</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#GRUCell.initialize_carry"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.GRUCell.initialize_carry" title="Permalink to this definition">#</a></dt> <dd><p>Initialize the RNN cell carry.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>rng</strong> – random number generator passed to the init_fn.</p></li> <li><p><strong>input_shape</strong> – a tuple providing the shape of the input to the cell.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>An initialized carry for the given RNN cell.</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> <tr class="row-odd"><td><p><a class="reference internal" href="#flax.linen.GRUCell.initialize_carry" title="flax.linen.GRUCell.initialize_carry"><code class="xref py py-obj docutils literal notranslate"><span class="pre">initialize_carry</span></code></a>(rng, input_shape)</p></td> <td><p>Initialize the RNN cell carry.</p></td> </tr> </tbody> </table> </div> </dd></dl> </div> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.MGUCell"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">MGUCell</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">features</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">gate_fn=<PjitFunction</span> <span class="pre">of</span> <span class="pre"><function</span> <span class="pre">sigmoid>></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">activation_fn=<PjitFunction</span> <span class="pre">of</span> <span class="pre"><function</span> <span class="pre">tanh>></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">kernel_init=<function</span> <span class="pre">variance_scaling.<locals>.init></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">recurrent_kernel_init=<function</span> <span class="pre">orthogonal.<locals>.init></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">forget_bias_init=<function</span> <span class="pre">ones></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">activation_bias_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">param_dtype=<class</span> <span class="pre">'jax.numpy.float32'></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">carry_init=<function</span> <span class="pre">zeros></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reset_gate=True</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#MGUCell"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.MGUCell" title="Permalink to this definition">#</a></dt> <dd><p>MGU cell (<a class="reference external" href="https://arxiv.org/pdf/1603.09420.pdf">https://arxiv.org/pdf/1603.09420.pdf</a>).</p> <p>The mathematical definition of the cell is as follows</p> <div class="math notranslate nohighlight"> \[\begin{split}\begin{array}{ll} f = \sigma(W_{if} x + b_{if} + W_{hf} h) \\ n = \tanh(W_{in} x + b_{in} + f * (W_{hn} h + b_{hn})) \\ h' = (1 - f) * n + f * h \\ \end{array}\end{split}\]</div> <p>where x is the input and h is the output of the previous time step.</p> <p>If <code class="docutils literal notranslate"><span class="pre">reset_gate</span></code> is false, the above becomes</p> <div class="math notranslate nohighlight"> \[\begin{split}\begin{array}{ll} f = \sigma(W_{if} x + b_{if} + W_{hf} h) \\ n = \tanh(W_{in} x + b_{in} + W_{hn} h) \\ h' = (1 - f) * n + f * h \\ \end{array}\end{split}\]</div> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">MGUCell</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">carry</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">initialize_carry</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">carry</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">new_carry</span><span class="p">,</span> <span class="n">out</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">carry</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MGUCell.features"> <span class="sig-name descname"><span class="pre">features</span></span><a class="headerlink" href="#flax.linen.MGUCell.features" title="Permalink to this definition">#</a></dt> <dd><p>number of output features.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MGUCell.gate_fn"> <span class="sig-name descname"><span class="pre">gate_fn</span></span><a class="headerlink" href="#flax.linen.MGUCell.gate_fn" title="Permalink to this definition">#</a></dt> <dd><p>activation function used for gates (default: sigmoid).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Callable[[…], Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MGUCell.activation_fn"> <span class="sig-name descname"><span class="pre">activation_fn</span></span><a class="headerlink" href="#flax.linen.MGUCell.activation_fn" title="Permalink to this definition">#</a></dt> <dd><p>activation function used for output and memory update (default: tanh).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Callable[[…], Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MGUCell.kernel_init"> <span class="sig-name descname"><span class="pre">kernel_init</span></span><a class="headerlink" href="#flax.linen.MGUCell.kernel_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer function for the kernels that transform the input (default: lecun_normal).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MGUCell.recurrent_kernel_init"> <span class="sig-name descname"><span class="pre">recurrent_kernel_init</span></span><a class="headerlink" href="#flax.linen.MGUCell.recurrent_kernel_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer function for the kernels that transform the hidden state (default: initializers.orthogonal()).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MGUCell.forget_bias_init"> <span class="sig-name descname"><span class="pre">forget_bias_init</span></span><a class="headerlink" href="#flax.linen.MGUCell.forget_bias_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer for the bias parameters of the forget gate. The default is set to initializers.ones_init() because this prevents vanishing gradients. See <a class="reference external" href="https://proceedings.mlr.press/v37/jozefowicz15.pdf">https://proceedings.mlr.press/v37/jozefowicz15.pdf</a>, section 2.2 for more details.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MGUCell.activation_bias_init"> <span class="sig-name descname"><span class="pre">activation_bias_init</span></span><a class="headerlink" href="#flax.linen.MGUCell.activation_bias_init" title="Permalink to this definition">#</a></dt> <dd><p>initializer for the bias parameters of the activation output (default: initializers.zeros_init()).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[jax.nn.initializers.Initializer, collections.abc.Callable[[…], Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MGUCell.dtype"> <span class="sig-name descname"><span class="pre">dtype</span></span><a class="headerlink" href="#flax.linen.MGUCell.dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype of the computation (default: None).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Optional[Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MGUCell.param_dtype"> <span class="sig-name descname"><span class="pre">param_dtype</span></span><a class="headerlink" href="#flax.linen.MGUCell.param_dtype" title="Permalink to this definition">#</a></dt> <dd><p>the dtype passed to parameter initializers (default: float32).</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[str, type[<em>Any</em>], numpy.dtype, jax._src.typing.SupportsDType, Any]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.MGUCell.reset_gate"> <span class="sig-name descname"><span class="pre">reset_gate</span></span><a class="headerlink" href="#flax.linen.MGUCell.reset_gate" title="Permalink to this definition">#</a></dt> <dd><p>flag for applying reset gating.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.MGUCell.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">carry</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">inputs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#MGUCell.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.MGUCell.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Minimal gated unit (MGU) cell.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>carry</strong> – the hidden state of the MGU cell, initialized using <code class="docutils literal notranslate"><span class="pre">MGUCell.initialize_carry</span></code>.</p></li> <li><p><strong>inputs</strong> – an ndarray with the input for the current time step. All dimensions except the final are considered batch dimensions.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>A tuple with the new carry and the output.</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.MGUCell.initialize_carry"> <span class="sig-name descname"><span class="pre">initialize_carry</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">rng</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">input_shape</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#MGUCell.initialize_carry"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.MGUCell.initialize_carry" title="Permalink to this definition">#</a></dt> <dd><p>Initialize the RNN cell carry.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>rng</strong> – random number generator passed to the init_fn.</p></li> <li><p><strong>input_shape</strong> – a tuple providing the shape of the input to the cell.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>An initialized carry for the given RNN cell.</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> <tr class="row-odd"><td><p><a class="reference internal" href="#flax.linen.MGUCell.initialize_carry" title="flax.linen.MGUCell.initialize_carry"><code class="xref py py-obj docutils literal notranslate"><span class="pre">initialize_carry</span></code></a>(rng, input_shape)</p></td> <td><p>Initialize the RNN cell carry.</p></td> </tr> </tbody> </table> </div> </dd></dl> </div> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.RNN"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">RNN</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">cell</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">time_major=False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">return_carry=False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reverse=False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">keep_order=False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">unroll=1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">variable_axes=FrozenDict({})</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">variable_broadcast='params'</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">variable_carry=False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">split_rngs=FrozenDict({</span>     <span class="pre">params:</span> <span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">})</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#RNN"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.RNN" title="Permalink to this definition">#</a></dt> <dd><p>The <code class="docutils literal notranslate"><span class="pre">RNN</span></code> module takes any <a class="reference internal" href="#flax.linen.RNNCellBase" title="flax.linen.RNNCellBase"><code class="xref py py-class docutils literal notranslate"><span class="pre">RNNCellBase</span></code></a> instance and applies it over a sequence</p> <p>using <a class="reference internal" href="transformations.html#flax.linen.scan" title="flax.linen.scan"><code class="xref py py-func docutils literal notranslate"><span class="pre">flax.linen.scan()</span></code></a>.</p> <p>Example:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">10</span><span class="p">,</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">32</span><span class="p">))</span> <span class="c1"># (batch, time, features)</span> <span class="gp">>>> </span><span class="n">lstm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">RNN</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">LSTMCell</span><span class="p">(</span><span class="mi">64</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">lstm</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">y</span> <span class="o">=</span> <span class="n">lstm</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">y</span><span class="o">.</span><span class="n">shape</span> <span class="c1"># (batch, time, cell_size)</span> <span class="go">(10, 50, 64)</span> </pre></div> </div> <p>As shown above, RNN uses the <code class="docutils literal notranslate"><span class="pre">cell_size</span></code> argument to set the <code class="docutils literal notranslate"><span class="pre">size</span></code> argument for the cell’s <code class="docutils literal notranslate"><span class="pre">initialize_carry</span></code> method, in practice this is typically the number of hidden units you want for the cell. However, this may vary depending on the cell you are using, for example the <a class="reference internal" href="#flax.linen.ConvLSTMCell" title="flax.linen.ConvLSTMCell"><code class="xref py py-class docutils literal notranslate"><span class="pre">ConvLSTMCell</span></code></a> requires a <code class="docutils literal notranslate"><span class="pre">size</span></code> argument of the form <code class="docutils literal notranslate"><span class="pre">(kernel_height,</span> <span class="pre">kernel_width,</span> <span class="pre">features)</span></code>:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">10</span><span class="p">,</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="c1"># (batch, time, height, width, features)</span> <span class="gp">>>> </span><span class="n">conv_lstm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">RNN</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">ConvLSTMCell</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))</span> <span class="gp">>>> </span><span class="n">y</span><span class="p">,</span> <span class="n">variables</span> <span class="o">=</span> <span class="n">conv_lstm</span><span class="o">.</span><span class="n">init_with_output</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">y</span><span class="o">.</span><span class="n">shape</span> <span class="c1"># (batch, time, height, width, features)</span> <span class="go">(10, 50, 32, 32, 64)</span> </pre></div> </div> <p>By default RNN expect the time dimension after the batch dimension (<code class="docutils literal notranslate"><span class="pre">(*batch,</span> <span class="pre">time,</span> <span class="pre">*features)</span></code>), if you set <code class="docutils literal notranslate"><span class="pre">time_major=True</span></code> RNN will instead expect the time dimesion to be at the beginning (<code class="docutils literal notranslate"><span class="pre">(time,</span> <span class="pre">*batch,</span> <span class="pre">*features)</span></code>):</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">50</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">32</span><span class="p">))</span> <span class="c1"># (time, batch, features)</span> <span class="gp">>>> </span><span class="n">lstm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">RNN</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">LSTMCell</span><span class="p">(</span><span class="mi">64</span><span class="p">),</span> <span class="n">time_major</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">lstm</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">y</span> <span class="o">=</span> <span class="n">lstm</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">y</span><span class="o">.</span><span class="n">shape</span> <span class="c1"># (time, batch, cell_size)</span> <span class="go">(50, 10, 64)</span> </pre></div> </div> <p>The output is an array of shape <code class="docutils literal notranslate"><span class="pre">(*batch,</span> <span class="pre">time,</span> <span class="pre">*cell_size)</span></code> by default (typically), however if you set <code class="docutils literal notranslate"><span class="pre">return_carry=True</span></code> it will instead return a tuple of the final carry and the output:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">10</span><span class="p">,</span> <span class="mi">50</span><span class="p">,</span> <span class="mi">32</span><span class="p">))</span> <span class="c1"># (batch, time, features)</span> <span class="gp">>>> </span><span class="n">lstm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">RNN</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">LSTMCell</span><span class="p">(</span><span class="mi">64</span><span class="p">),</span> <span class="n">return_carry</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">lstm</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">carry</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">lstm</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">carry</span><span class="p">)</span> <span class="c1"># ((batch, cell_size), (batch, cell_size))</span> <span class="go">((10, 64), (10, 64))</span> <span class="gp">>>> </span><span class="n">y</span><span class="o">.</span><span class="n">shape</span> <span class="c1"># (batch, time, cell_size)</span> <span class="go">(10, 50, 64)</span> </pre></div> </div> <p>To support variable length sequences, you can pass a <code class="docutils literal notranslate"><span class="pre">seq_lengths</span></code> which is an integer array of shape <code class="docutils literal notranslate"><span class="pre">(*batch)</span></code> where each element is the length of the sequence in the batch. For example:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">seq_lengths</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">5</span><span class="p">])</span> </pre></div> </div> <p>The output elements corresponding to padding elements are NOT zeroed out. If <code class="docutils literal notranslate"><span class="pre">return_carry</span></code> is set to <code class="docutils literal notranslate"><span class="pre">True</span></code> the carry will be the state of the last valid element of each sequence.</p> <p>RNN also accepts some of the arguments of <a class="reference internal" href="transformations.html#flax.linen.scan" title="flax.linen.scan"><code class="xref py py-func docutils literal notranslate"><span class="pre">flax.linen.scan()</span></code></a>, by default they are set to work with cells like <a class="reference internal" href="#flax.linen.LSTMCell" title="flax.linen.LSTMCell"><code class="xref py py-class docutils literal notranslate"><span class="pre">LSTMCell</span></code></a> and <a class="reference internal" href="#flax.linen.GRUCell" title="flax.linen.GRUCell"><code class="xref py py-class docutils literal notranslate"><span class="pre">GRUCell</span></code></a> but they can be overriden as needed. Overriding default values to scan looks like this:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">lstm</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">RNN</span><span class="p">(</span> <span class="gp">... </span> <span class="n">nn</span><span class="o">.</span><span class="n">LSTMCell</span><span class="p">(</span><span class="mi">64</span><span class="p">),</span> <span class="gp">... </span> <span class="n">unroll</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">variable_axes</span><span class="o">=</span><span class="p">{},</span> <span class="n">variable_broadcast</span><span class="o">=</span><span class="s1">'params'</span><span class="p">,</span> <span class="gp">... </span> <span class="n">variable_carry</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">split_rngs</span><span class="o">=</span><span class="p">{</span><span class="s1">'params'</span><span class="p">:</span> <span class="kc">False</span><span class="p">})</span> </pre></div> </div> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.RNN.cell"> <span class="sig-name descname"><span class="pre">cell</span></span><a class="headerlink" href="#flax.linen.RNN.cell" title="Permalink to this definition">#</a></dt> <dd><p>an instance of <a class="reference internal" href="#flax.linen.RNNCellBase" title="flax.linen.RNNCellBase"><code class="xref py py-class docutils literal notranslate"><span class="pre">RNNCellBase</span></code></a>.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p><a class="reference internal" href="#flax.linen.RNNCellBase" title="flax.linen.recurrent.RNNCellBase">flax.linen.recurrent.RNNCellBase</a></p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.RNN.time_major"> <span class="sig-name descname"><span class="pre">time_major</span></span><a class="headerlink" href="#flax.linen.RNN.time_major" title="Permalink to this definition">#</a></dt> <dd><p>if <code class="docutils literal notranslate"><span class="pre">time_major=False</span></code> (default) it will expect inputs with shape <code class="docutils literal notranslate"><span class="pre">(*batch,</span> <span class="pre">time,</span> <span class="pre">*features)</span></code>, else it will expect inputs with shape <code class="docutils literal notranslate"><span class="pre">(time,</span> <span class="pre">*batch,</span> <span class="pre">*features)</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.RNN.return_carry"> <span class="sig-name descname"><span class="pre">return_carry</span></span><a class="headerlink" href="#flax.linen.RNN.return_carry" title="Permalink to this definition">#</a></dt> <dd><p>if <code class="docutils literal notranslate"><span class="pre">return_carry=False</span></code> (default) only the output sequence is returned, else it will return a tuple of the final carry and the output sequence.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.RNN.reverse"> <span class="sig-name descname"><span class="pre">reverse</span></span><a class="headerlink" href="#flax.linen.RNN.reverse" title="Permalink to this definition">#</a></dt> <dd><p>if <code class="docutils literal notranslate"><span class="pre">reverse=False</span></code> (default) the sequence is processed from left to right and returned in the original order, else it will be processed from right to left, and returned in reverse order. If <code class="docutils literal notranslate"><span class="pre">seq_lengths</span></code> is passed, padding will always remain at the end of the sequence.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.RNN.keep_order"> <span class="sig-name descname"><span class="pre">keep_order</span></span><a class="headerlink" href="#flax.linen.RNN.keep_order" title="Permalink to this definition">#</a></dt> <dd><p>if <code class="docutils literal notranslate"><span class="pre">keep_order=True</span></code>, when <code class="docutils literal notranslate"><span class="pre">reverse=True</span></code> the output will be reversed back to the original order after processing, this is useful to align sequences in bidirectional RNNs. If <code class="docutils literal notranslate"><span class="pre">keep_order=False</span></code> (default), the output will remain in the order specified by <code class="docutils literal notranslate"><span class="pre">reverse</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>bool</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.RNN.unroll"> <span class="sig-name descname"><span class="pre">unroll</span></span><a class="headerlink" href="#flax.linen.RNN.unroll" title="Permalink to this definition">#</a></dt> <dd><p>how many scan iterations to unroll within a single iteration of a loop, defaults to 1. This argument will be passed to <code class="docutils literal notranslate"><span class="pre">nn.scan</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>int</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.RNN.variable_axes"> <span class="sig-name descname"><span class="pre">variable_axes</span></span><a class="headerlink" href="#flax.linen.RNN.variable_axes" title="Permalink to this definition">#</a></dt> <dd><p>a dictionary mapping each collection to either an integer <code class="docutils literal notranslate"><span class="pre">i</span></code> (meaning we scan over dimension <code class="docutils literal notranslate"><span class="pre">i</span></code>) or <code class="docutils literal notranslate"><span class="pre">None</span></code> (replicate rather than scan). This argument is forwarded to <code class="docutils literal notranslate"><span class="pre">nn.scan</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Mapping[Union[bool, str, Collection[str], DenyList], Union[int, flax.typing.In[int], flax.typing.Out[int]]]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.RNN.variable_broadcast"> <span class="sig-name descname"><span class="pre">variable_broadcast</span></span><a class="headerlink" href="#flax.linen.RNN.variable_broadcast" title="Permalink to this definition">#</a></dt> <dd><p>Specifies the broadcasted variable collections. A broadcasted variable should not depend on any computation that cannot be lifted out of the loop. This is typically used to define shared parameters inside the fn. This argument is forwarded to <code class="docutils literal notranslate"><span class="pre">nn.scan</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[bool, str, Collection[str], DenyList]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.RNN.variable_carry"> <span class="sig-name descname"><span class="pre">variable_carry</span></span><a class="headerlink" href="#flax.linen.RNN.variable_carry" title="Permalink to this definition">#</a></dt> <dd><p>Specifies the variable collections that are carried through the loop. Mutations to these variables are carried to the next iteration and will be preserved when the scan finishes. This argument is forwarded to <code class="docutils literal notranslate"><span class="pre">nn.scan</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>Union[bool, str, Collection[str], DenyList]</p> </dd> </dl> </dd></dl> <dl class="py attribute"> <dt class="sig sig-object py" id="flax.linen.RNN.split_rngs"> <span class="sig-name descname"><span class="pre">split_rngs</span></span><a class="headerlink" href="#flax.linen.RNN.split_rngs" title="Permalink to this definition">#</a></dt> <dd><p>a mapping from PRNGSequenceFilter to bool specifying whether a collection’s PRNG key should be split such that its values are different at each step, or replicated such that its values remain the same at each step. This argument is forwarded to <code class="docutils literal notranslate"><span class="pre">nn.scan</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Type</dt> <dd class="field-odd"><p>collections.abc.Mapping[Union[bool, str, Collection[str], DenyList], bool]</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.RNN.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">inputs</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">initial_carry</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">init_key</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">seq_lengths</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">return_carry</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">time_major</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reverse</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">keep_order</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#RNN.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.RNN.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Applies the RNN to the inputs.</p> <p><code class="docutils literal notranslate"><span class="pre">__call__</span></code> allows you to optionally override some attributes like <code class="docutils literal notranslate"><span class="pre">return_carry</span></code> and <code class="docutils literal notranslate"><span class="pre">time_major</span></code> defined in the constructor.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>inputs</strong> – the input sequence.</p></li> <li><p><strong>initial_carry</strong> – the initial carry, if not provided it will be initialized using the cell’s <a class="reference internal" href="#flax.linen.RNNCellBase.initialize_carry" title="flax.linen.RNNCellBase.initialize_carry"><code class="xref py py-meth docutils literal notranslate"><span class="pre">RNNCellBase.initialize_carry()</span></code></a> method.</p></li> <li><p><strong>init_key</strong> – a PRNG key used to initialize the carry, if not provided <code class="docutils literal notranslate"><span class="pre">jax.random.key(0)</span></code> will be used. Most cells will ignore this argument.</p></li> <li><p><strong>seq_lengths</strong> – an optional integer array of shape <code class="docutils literal notranslate"><span class="pre">(*batch)</span></code> indicating the length of each sequence, elements whose index in the time dimension is greater than the corresponding length will be considered padding and will be ignored.</p></li> <li><p><strong>return_carry</strong> – if <code class="docutils literal notranslate"><span class="pre">return_carry=False</span></code> (default) only the output sequence is returned, else it will return a tuple of the final carry and the output sequence.</p></li> <li><p><strong>time_major</strong> – if <code class="docutils literal notranslate"><span class="pre">time_major=False</span></code> (default) it will expect inputs with shape <code class="docutils literal notranslate"><span class="pre">(*batch,</span> <span class="pre">time,</span> <span class="pre">*features)</span></code>, else it will expect inputs with shape <code class="docutils literal notranslate"><span class="pre">(time,</span> <span class="pre">*batch,</span> <span class="pre">*features)</span></code>.</p></li> <li><p><strong>reverse</strong> – overrides the <code class="docutils literal notranslate"><span class="pre">reverse</span></code> attribute, if <code class="docutils literal notranslate"><span class="pre">reverse=False</span></code> (default) the sequence is processed from left to right and returned in the original order, else it will be processed from right to left, and returned in reverse order. If <code class="docutils literal notranslate"><span class="pre">seq_lengths</span></code> is passed, padding will always remain at the end of the sequence.</p></li> <li><p><strong>keep_order</strong> – overrides the <code class="docutils literal notranslate"><span class="pre">keep_order</span></code> attribute, if <code class="docutils literal notranslate"><span class="pre">keep_order=True</span></code>, when <code class="docutils literal notranslate"><span class="pre">reverse=True</span></code> the output will be reversed back to the original order after processing, this is useful to align sequences in bidirectional RNNs. If <code class="docutils literal notranslate"><span class="pre">keep_order=False</span></code> (default), the output will remain in the order specified by <code class="docutils literal notranslate"><span class="pre">reverse</span></code>.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>if <code class="docutils literal notranslate"><span class="pre">return_carry=False</span></code> (default) only the output sequence is returned, else it will return a tuple of the final carry and the output sequence.</p> </dd> </dl> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> </tbody> </table> </div> </dd></dl> </div> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.Bidirectional"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">Bidirectional</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">forward_rnn</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">backward_rnn</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">merge_fn=<function</span> <span class="pre">_concatenate></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">time_major=False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">return_carry=False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=<flax.linen.module._Sentinel</span> <span class="pre">object></span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#Bidirectional"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Bidirectional" title="Permalink to this definition">#</a></dt> <dd><p>Processes the input in both directions and merges the results.</p> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="n">layer</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Bidirectional</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">RNN</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">GRUCell</span><span class="p">(</span><span class="mi">4</span><span class="p">)),</span> <span class="n">nn</span><span class="o">.</span><span class="n">RNN</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">GRUCell</span><span class="p">(</span><span class="mi">4</span><span class="p">)))</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">out</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> </pre></div> </div> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Bidirectional.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">inputs</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">initial_carry</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">init_key</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">seq_lengths</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">return_carry</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">time_major</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reverse</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">keep_order</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/recurrent.html#Bidirectional.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Bidirectional.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Call self as a function.</p> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> </tbody> </table> </div> </dd></dl> </div> </div> <div class="section" id="batchapply"> <h2>BatchApply<a class="headerlink" href="#batchapply" title="Permalink to this heading">#</a></h2> <div class="docutils container"> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.BatchApply"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">BatchApply</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">f</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">num_dims</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">2</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/batch_apply.html#BatchApply"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.BatchApply" title="Permalink to this definition">#</a></dt> <dd><p>Temporarily merges leading dimensions of input tensors.</p> <p>Merges the leading dimensions of a tensor into a single dimension, runs the given callable, then splits the leading dimension of the result to match the input.</p> <p>Input arrays whose rank is smaller than the number of dimensions to collapse are passed unmodified.</p> <p>This may be useful for applying a module to each timestep of e.g. a <code class="docutils literal notranslate"><span class="pre">[Time,</span> <span class="pre">Batch,</span> <span class="pre">...]</span></code> array.</p> <p>For some <code class="docutils literal notranslate"><span class="pre">f</span></code>s and platforms, this may be more efficient than <code class="xref py py-func docutils literal notranslate"><span class="pre">jax.vmap()</span></code>, especially when combined with other transformations like <code class="xref py py-func docutils literal notranslate"><span class="pre">jax.grad()</span></code>.</p> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="n">a</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="p">[</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">])</span> <span class="gp">>>> </span><span class="n">b</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="p">[</span><span class="mi">4</span><span class="p">])</span> <span class="gp">>>> </span><span class="k">def</span> <span class="nf">raises</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span> <span class="gp">... </span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">2</span><span class="p">:</span> <span class="gp">... </span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"a must be shape 2"</span><span class="p">)</span> <span class="gp">... </span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">b</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">!=</span> <span class="mi">1</span><span class="p">:</span> <span class="gp">... </span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"b must be shape 1"</span><span class="p">)</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">jnp</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">out</span> <span class="o">=</span> <span class="n">BatchApply</span><span class="p">(</span><span class="n">raises</span><span class="p">)(</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">expected_merged_leading</span> <span class="o">=</span> <span class="n">raises</span><span class="p">(</span><span class="n">a</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">2</span><span class="o">*</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span> <span class="n">b</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">expected</span> <span class="o">=</span> <span class="n">expected_merged_leading</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="o">+</span> <span class="n">expected_merged_leading</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">:])</span> <span class="gp">>>> </span><span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_array_equal</span><span class="p">(</span><span class="n">out</span><span class="p">,</span> <span class="n">expected</span><span class="p">)</span> </pre></div> </div> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.BatchApply.__call__"> <span class="sig-name descname"><span class="pre">__call__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/batch_apply.html#BatchApply.__call__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.BatchApply.__call__" title="Permalink to this definition">#</a></dt> <dd><p>Call self as a function.</p> </dd></dl> <p class="rubric">Methods</p> <div class="pst-scrollable-table-container"><table class="autosummary longtable table autosummary"> <colgroup> <col style="width: 10%" /> <col style="width: 90%" /> </colgroup> <tbody> </tbody> </table> </div> </dd></dl> </div> </div> </div> </article> <footer class="prev-next-footer d-print-none"> <div class="prev-next-area"> <a class="left-prev" href="init_apply.html" title="previous page"> <i class="fa-solid fa-angle-left"></i> <div class="prev-next-info"> <p class="prev-next-subtitle">previous</p> <p class="prev-next-title">Init/Apply</p> </div> </a> <a class="right-next" href="activation_functions.html" title="next page"> <div class="prev-next-info"> <p class="prev-next-subtitle">next</p> <p class="prev-next-title">Activation functions</p> </div> <i class="fa-solid fa-angle-right"></i> </a> </div> </footer> </div> <div class="bd-sidebar-secondary bd-toc"><div class="sidebar-secondary-items sidebar-secondary__inner"> <div class="sidebar-secondary-item"> <div class="page-toc tocsection onthispage"> <i class="fa-solid fa-list"></i> Contents </div> <nav class="bd-toc-nav page-toc"> <ul class="visible nav section-nav flex-column"> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#linear-modules">Linear Modules</a><ul class="nav section-nav flex-column"> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dense"><code class="docutils literal notranslate"><span class="pre">Dense</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dense.features"><code class="docutils literal notranslate"><span class="pre">Dense.features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dense.use_bias"><code class="docutils literal notranslate"><span class="pre">Dense.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dense.dtype"><code class="docutils literal notranslate"><span class="pre">Dense.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dense.param_dtype"><code class="docutils literal notranslate"><span class="pre">Dense.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dense.precision"><code class="docutils literal notranslate"><span class="pre">Dense.precision</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dense.kernel_init"><code class="docutils literal notranslate"><span class="pre">Dense.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dense.bias_init"><code class="docutils literal notranslate"><span class="pre">Dense.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dense.__call__"><code class="docutils literal notranslate"><span class="pre">Dense.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.DenseGeneral"><code class="docutils literal notranslate"><span class="pre">DenseGeneral</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.DenseGeneral.features"><code class="docutils literal notranslate"><span class="pre">DenseGeneral.features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.DenseGeneral.axis"><code class="docutils literal notranslate"><span class="pre">DenseGeneral.axis</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.DenseGeneral.batch_dims"><code class="docutils literal notranslate"><span class="pre">DenseGeneral.batch_dims</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.DenseGeneral.use_bias"><code class="docutils literal notranslate"><span class="pre">DenseGeneral.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.DenseGeneral.dtype"><code class="docutils literal notranslate"><span class="pre">DenseGeneral.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.DenseGeneral.param_dtype"><code class="docutils literal notranslate"><span class="pre">DenseGeneral.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.DenseGeneral.kernel_init"><code class="docutils literal notranslate"><span class="pre">DenseGeneral.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.DenseGeneral.bias_init"><code class="docutils literal notranslate"><span class="pre">DenseGeneral.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.DenseGeneral.precision"><code class="docutils literal notranslate"><span class="pre">DenseGeneral.precision</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.DenseGeneral.__call__"><code class="docutils literal notranslate"><span class="pre">DenseGeneral.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv"><code class="docutils literal notranslate"><span class="pre">Conv</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.features"><code class="docutils literal notranslate"><span class="pre">Conv.features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.kernel_size"><code class="docutils literal notranslate"><span class="pre">Conv.kernel_size</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.strides"><code class="docutils literal notranslate"><span class="pre">Conv.strides</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.padding"><code class="docutils literal notranslate"><span class="pre">Conv.padding</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.input_dilation"><code class="docutils literal notranslate"><span class="pre">Conv.input_dilation</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.kernel_dilation"><code class="docutils literal notranslate"><span class="pre">Conv.kernel_dilation</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.feature_group_count"><code class="docutils literal notranslate"><span class="pre">Conv.feature_group_count</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.use_bias"><code class="docutils literal notranslate"><span class="pre">Conv.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.mask"><code class="docutils literal notranslate"><span class="pre">Conv.mask</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.dtype"><code class="docutils literal notranslate"><span class="pre">Conv.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.param_dtype"><code class="docutils literal notranslate"><span class="pre">Conv.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.precision"><code class="docutils literal notranslate"><span class="pre">Conv.precision</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.kernel_init"><code class="docutils literal notranslate"><span class="pre">Conv.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.bias_init"><code class="docutils literal notranslate"><span class="pre">Conv.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Conv.__call__"><code class="docutils literal notranslate"><span class="pre">Conv.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose"><code class="docutils literal notranslate"><span class="pre">ConvTranspose</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.features"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.kernel_size"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.kernel_size</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.strides"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.strides</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.padding"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.padding</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.kernel_dilation"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.kernel_dilation</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.use_bias"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.mask"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.mask</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.dtype"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.param_dtype"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.precision"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.precision</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.kernel_init"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.bias_init"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.transpose_kernel"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.transpose_kernel</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvTranspose.__call__"><code class="docutils literal notranslate"><span class="pre">ConvTranspose.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal"><code class="docutils literal notranslate"><span class="pre">ConvLocal</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.features"><code class="docutils literal notranslate"><span class="pre">ConvLocal.features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.kernel_size"><code class="docutils literal notranslate"><span class="pre">ConvLocal.kernel_size</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.strides"><code class="docutils literal notranslate"><span class="pre">ConvLocal.strides</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.padding"><code class="docutils literal notranslate"><span class="pre">ConvLocal.padding</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.input_dilation"><code class="docutils literal notranslate"><span class="pre">ConvLocal.input_dilation</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.kernel_dilation"><code class="docutils literal notranslate"><span class="pre">ConvLocal.kernel_dilation</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.feature_group_count"><code class="docutils literal notranslate"><span class="pre">ConvLocal.feature_group_count</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.use_bias"><code class="docutils literal notranslate"><span class="pre">ConvLocal.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.mask"><code class="docutils literal notranslate"><span class="pre">ConvLocal.mask</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.dtype"><code class="docutils literal notranslate"><span class="pre">ConvLocal.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.param_dtype"><code class="docutils literal notranslate"><span class="pre">ConvLocal.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.precision"><code class="docutils literal notranslate"><span class="pre">ConvLocal.precision</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.kernel_init"><code class="docutils literal notranslate"><span class="pre">ConvLocal.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.bias_init"><code class="docutils literal notranslate"><span class="pre">ConvLocal.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLocal.__call__"><code class="docutils literal notranslate"><span class="pre">ConvLocal.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Einsum"><code class="docutils literal notranslate"><span class="pre">Einsum</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Einsum.shape"><code class="docutils literal notranslate"><span class="pre">Einsum.shape</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Einsum.einsum_str"><code class="docutils literal notranslate"><span class="pre">Einsum.einsum_str</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Einsum.use_bias"><code class="docutils literal notranslate"><span class="pre">Einsum.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Einsum.dtype"><code class="docutils literal notranslate"><span class="pre">Einsum.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Einsum.param_dtype"><code class="docutils literal notranslate"><span class="pre">Einsum.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Einsum.precision"><code class="docutils literal notranslate"><span class="pre">Einsum.precision</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Einsum.kernel_init"><code class="docutils literal notranslate"><span class="pre">Einsum.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Einsum.bias_init"><code class="docutils literal notranslate"><span class="pre">Einsum.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Einsum.__call__"><code class="docutils literal notranslate"><span class="pre">Einsum.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Embed"><code class="docutils literal notranslate"><span class="pre">Embed</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Embed.num_embeddings"><code class="docutils literal notranslate"><span class="pre">Embed.num_embeddings</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Embed.features"><code class="docutils literal notranslate"><span class="pre">Embed.features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Embed.dtype"><code class="docutils literal notranslate"><span class="pre">Embed.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Embed.param_dtype"><code class="docutils literal notranslate"><span class="pre">Embed.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Embed.embedding_init"><code class="docutils literal notranslate"><span class="pre">Embed.embedding_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Embed.__call__"><code class="docutils literal notranslate"><span class="pre">Embed.__call__()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Embed.attend"><code class="docutils literal notranslate"><span class="pre">Embed.attend()</span></code></a></li> </ul> </li> </ul> </li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#pooling">Pooling</a><ul class="nav section-nav flex-column"> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.max_pool"><code class="docutils literal notranslate"><span class="pre">max_pool()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.avg_pool"><code class="docutils literal notranslate"><span class="pre">avg_pool()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.pool"><code class="docutils literal notranslate"><span class="pre">pool()</span></code></a></li> </ul> </li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#normalization">Normalization</a><ul class="nav section-nav flex-column"> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm"><code class="docutils literal notranslate"><span class="pre">BatchNorm</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.use_running_average"><code class="docutils literal notranslate"><span class="pre">BatchNorm.use_running_average</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.axis"><code class="docutils literal notranslate"><span class="pre">BatchNorm.axis</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.momentum"><code class="docutils literal notranslate"><span class="pre">BatchNorm.momentum</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.epsilon"><code class="docutils literal notranslate"><span class="pre">BatchNorm.epsilon</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.dtype"><code class="docutils literal notranslate"><span class="pre">BatchNorm.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.param_dtype"><code class="docutils literal notranslate"><span class="pre">BatchNorm.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.use_bias"><code class="docutils literal notranslate"><span class="pre">BatchNorm.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.use_scale"><code class="docutils literal notranslate"><span class="pre">BatchNorm.use_scale</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.bias_init"><code class="docutils literal notranslate"><span class="pre">BatchNorm.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.scale_init"><code class="docutils literal notranslate"><span class="pre">BatchNorm.scale_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.axis_name"><code class="docutils literal notranslate"><span class="pre">BatchNorm.axis_name</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.axis_index_groups"><code class="docutils literal notranslate"><span class="pre">BatchNorm.axis_index_groups</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.use_fast_variance"><code class="docutils literal notranslate"><span class="pre">BatchNorm.use_fast_variance</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchNorm.__call__"><code class="docutils literal notranslate"><span class="pre">BatchNorm.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm"><code class="docutils literal notranslate"><span class="pre">LayerNorm</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.epsilon"><code class="docutils literal notranslate"><span class="pre">LayerNorm.epsilon</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.dtype"><code class="docutils literal notranslate"><span class="pre">LayerNorm.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.param_dtype"><code class="docutils literal notranslate"><span class="pre">LayerNorm.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.use_bias"><code class="docutils literal notranslate"><span class="pre">LayerNorm.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.use_scale"><code class="docutils literal notranslate"><span class="pre">LayerNorm.use_scale</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.bias_init"><code class="docutils literal notranslate"><span class="pre">LayerNorm.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.scale_init"><code class="docutils literal notranslate"><span class="pre">LayerNorm.scale_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.reduction_axes"><code class="docutils literal notranslate"><span class="pre">LayerNorm.reduction_axes</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.feature_axes"><code class="docutils literal notranslate"><span class="pre">LayerNorm.feature_axes</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.axis_name"><code class="docutils literal notranslate"><span class="pre">LayerNorm.axis_name</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.axis_index_groups"><code class="docutils literal notranslate"><span class="pre">LayerNorm.axis_index_groups</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.use_fast_variance"><code class="docutils literal notranslate"><span class="pre">LayerNorm.use_fast_variance</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LayerNorm.__call__"><code class="docutils literal notranslate"><span class="pre">LayerNorm.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm"><code class="docutils literal notranslate"><span class="pre">GroupNorm</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.num_groups"><code class="docutils literal notranslate"><span class="pre">GroupNorm.num_groups</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.group_size"><code class="docutils literal notranslate"><span class="pre">GroupNorm.group_size</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.epsilon"><code class="docutils literal notranslate"><span class="pre">GroupNorm.epsilon</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.dtype"><code class="docutils literal notranslate"><span class="pre">GroupNorm.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.param_dtype"><code class="docutils literal notranslate"><span class="pre">GroupNorm.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.use_bias"><code class="docutils literal notranslate"><span class="pre">GroupNorm.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.use_scale"><code class="docutils literal notranslate"><span class="pre">GroupNorm.use_scale</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.bias_init"><code class="docutils literal notranslate"><span class="pre">GroupNorm.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.scale_init"><code class="docutils literal notranslate"><span class="pre">GroupNorm.scale_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.reduction_axes"><code class="docutils literal notranslate"><span class="pre">GroupNorm.reduction_axes</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.axis_name"><code class="docutils literal notranslate"><span class="pre">GroupNorm.axis_name</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.axis_index_groups"><code class="docutils literal notranslate"><span class="pre">GroupNorm.axis_index_groups</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.use_fast_variance"><code class="docutils literal notranslate"><span class="pre">GroupNorm.use_fast_variance</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GroupNorm.__call__"><code class="docutils literal notranslate"><span class="pre">GroupNorm.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm"><code class="docutils literal notranslate"><span class="pre">RMSNorm</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm.epsilon"><code class="docutils literal notranslate"><span class="pre">RMSNorm.epsilon</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm.dtype"><code class="docutils literal notranslate"><span class="pre">RMSNorm.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm.param_dtype"><code class="docutils literal notranslate"><span class="pre">RMSNorm.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm.use_scale"><code class="docutils literal notranslate"><span class="pre">RMSNorm.use_scale</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm.scale_init"><code class="docutils literal notranslate"><span class="pre">RMSNorm.scale_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm.reduction_axes"><code class="docutils literal notranslate"><span class="pre">RMSNorm.reduction_axes</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm.feature_axes"><code class="docutils literal notranslate"><span class="pre">RMSNorm.feature_axes</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm.axis_name"><code class="docutils literal notranslate"><span class="pre">RMSNorm.axis_name</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm.axis_index_groups"><code class="docutils literal notranslate"><span class="pre">RMSNorm.axis_index_groups</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm.use_fast_variance"><code class="docutils literal notranslate"><span class="pre">RMSNorm.use_fast_variance</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RMSNorm.__call__"><code class="docutils literal notranslate"><span class="pre">RMSNorm.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm"><code class="docutils literal notranslate"><span class="pre">InstanceNorm</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.epsilon"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.epsilon</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.dtype"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.param_dtype"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.use_bias"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.use_scale"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.use_scale</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.bias_init"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.scale_init"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.scale_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.feature_axes"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.feature_axes</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.axis_name"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.axis_name</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.axis_index_groups"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.axis_index_groups</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.use_fast_variance"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.use_fast_variance</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.InstanceNorm.__call__"><code class="docutils literal notranslate"><span class="pre">InstanceNorm.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SpectralNorm"><code class="docutils literal notranslate"><span class="pre">SpectralNorm</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SpectralNorm.layer_instance"><code class="docutils literal notranslate"><span class="pre">SpectralNorm.layer_instance</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SpectralNorm.n_steps"><code class="docutils literal notranslate"><span class="pre">SpectralNorm.n_steps</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SpectralNorm.epsilon"><code class="docutils literal notranslate"><span class="pre">SpectralNorm.epsilon</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SpectralNorm.dtype"><code class="docutils literal notranslate"><span class="pre">SpectralNorm.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SpectralNorm.param_dtype"><code class="docutils literal notranslate"><span class="pre">SpectralNorm.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SpectralNorm.error_on_non_matrix"><code class="docutils literal notranslate"><span class="pre">SpectralNorm.error_on_non_matrix</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SpectralNorm.collection_name"><code class="docutils literal notranslate"><span class="pre">SpectralNorm.collection_name</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SpectralNorm.__call__"><code class="docutils literal notranslate"><span class="pre">SpectralNorm.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.WeightNorm"><code class="docutils literal notranslate"><span class="pre">WeightNorm</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.WeightNorm.layer_instance"><code class="docutils literal notranslate"><span class="pre">WeightNorm.layer_instance</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.WeightNorm.epsilon"><code class="docutils literal notranslate"><span class="pre">WeightNorm.epsilon</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.WeightNorm.dtype"><code class="docutils literal notranslate"><span class="pre">WeightNorm.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.WeightNorm.param_dtype"><code class="docutils literal notranslate"><span class="pre">WeightNorm.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.WeightNorm.use_scale"><code class="docutils literal notranslate"><span class="pre">WeightNorm.use_scale</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.WeightNorm.scale_init"><code class="docutils literal notranslate"><span class="pre">WeightNorm.scale_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.WeightNorm.feature_axes"><code class="docutils literal notranslate"><span class="pre">WeightNorm.feature_axes</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.WeightNorm.variable_filter"><code class="docutils literal notranslate"><span class="pre">WeightNorm.variable_filter</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.WeightNorm.__call__"><code class="docutils literal notranslate"><span class="pre">WeightNorm.__call__()</span></code></a></li> </ul> </li> </ul> </li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#combinators">Combinators</a><ul class="nav section-nav flex-column"> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Sequential"><code class="docutils literal notranslate"><span class="pre">Sequential</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Sequential.layers"><code class="docutils literal notranslate"><span class="pre">Sequential.layers</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Sequential.__call__"><code class="docutils literal notranslate"><span class="pre">Sequential.__call__()</span></code></a></li> </ul> </li> </ul> </li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#stochastic">Stochastic</a><ul class="nav section-nav flex-column"> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dropout"><code class="docutils literal notranslate"><span class="pre">Dropout</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dropout.rate"><code class="docutils literal notranslate"><span class="pre">Dropout.rate</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dropout.broadcast_dims"><code class="docutils literal notranslate"><span class="pre">Dropout.broadcast_dims</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dropout.deterministic"><code class="docutils literal notranslate"><span class="pre">Dropout.deterministic</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dropout.rng_collection"><code class="docutils literal notranslate"><span class="pre">Dropout.rng_collection</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Dropout.__call__"><code class="docutils literal notranslate"><span class="pre">Dropout.__call__()</span></code></a></li> </ul> </li> </ul> </li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#attention">Attention</a><ul class="nav section-nav flex-column"> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.num_heads"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.num_heads</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.dtype"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.param_dtype"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.qkv_features"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.qkv_features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.out_features"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.out_features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.broadcast_dropout"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.broadcast_dropout</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.dropout_rate"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.dropout_rate</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.deterministic"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.deterministic</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.precision"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.precision</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.kernel_init"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.out_kernel_init"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.out_kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.bias_init"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.out_bias_init"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.out_bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.use_bias"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.attention_fn"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.attention_fn</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.decode"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.decode</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.normalize_qk"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.normalize_qk</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.qk_attn_weights_einsum_cls"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.qk_attn_weights_einsum_cls</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.attn_weights_value_einsum_cls"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.attn_weights_value_einsum_cls</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadDotProductAttention.__call__"><code class="docutils literal notranslate"><span class="pre">MultiHeadDotProductAttention.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.num_heads"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.num_heads</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.dtype"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.param_dtype"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.qkv_features"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.qkv_features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.out_features"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.out_features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.broadcast_dropout"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.broadcast_dropout</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.dropout_rate"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.dropout_rate</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.deterministic"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.deterministic</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.precision"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.precision</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.kernel_init"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.bias_init"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.use_bias"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.use_bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.attention_fn"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.attention_fn</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.decode"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.decode</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.normalize_qk"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.normalize_qk</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MultiHeadAttention.__call__"><code class="docutils literal notranslate"><span class="pre">MultiHeadAttention.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SelfAttention"><code class="docutils literal notranslate"><span class="pre">SelfAttention</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SelfAttention.__call__"><code class="docutils literal notranslate"><span class="pre">SelfAttention.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.dot_product_attention_weights"><code class="docutils literal notranslate"><span class="pre">dot_product_attention_weights()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.dot_product_attention"><code class="docutils literal notranslate"><span class="pre">dot_product_attention()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.make_attention_mask"><code class="docutils literal notranslate"><span class="pre">make_attention_mask()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.make_causal_mask"><code class="docutils literal notranslate"><span class="pre">make_causal_mask()</span></code></a></li> </ul> </li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#recurrent">Recurrent</a><ul class="nav section-nav flex-column"> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNNCellBase"><code class="docutils literal notranslate"><span class="pre">RNNCellBase</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNNCellBase.__call__"><code class="docutils literal notranslate"><span class="pre">RNNCellBase.__call__()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNNCellBase.initialize_carry"><code class="docutils literal notranslate"><span class="pre">RNNCellBase.initialize_carry()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LSTMCell"><code class="docutils literal notranslate"><span class="pre">LSTMCell</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LSTMCell.features"><code class="docutils literal notranslate"><span class="pre">LSTMCell.features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LSTMCell.gate_fn"><code class="docutils literal notranslate"><span class="pre">LSTMCell.gate_fn</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LSTMCell.activation_fn"><code class="docutils literal notranslate"><span class="pre">LSTMCell.activation_fn</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LSTMCell.kernel_init"><code class="docutils literal notranslate"><span class="pre">LSTMCell.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LSTMCell.recurrent_kernel_init"><code class="docutils literal notranslate"><span class="pre">LSTMCell.recurrent_kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LSTMCell.bias_init"><code class="docutils literal notranslate"><span class="pre">LSTMCell.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LSTMCell.dtype"><code class="docutils literal notranslate"><span class="pre">LSTMCell.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LSTMCell.param_dtype"><code class="docutils literal notranslate"><span class="pre">LSTMCell.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LSTMCell.__call__"><code class="docutils literal notranslate"><span class="pre">LSTMCell.__call__()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.LSTMCell.initialize_carry"><code class="docutils literal notranslate"><span class="pre">LSTMCell.initialize_carry()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.OptimizedLSTMCell"><code class="docutils literal notranslate"><span class="pre">OptimizedLSTMCell</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.OptimizedLSTMCell.gate_fn"><code class="docutils literal notranslate"><span class="pre">OptimizedLSTMCell.gate_fn</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.OptimizedLSTMCell.activation_fn"><code class="docutils literal notranslate"><span class="pre">OptimizedLSTMCell.activation_fn</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.OptimizedLSTMCell.kernel_init"><code class="docutils literal notranslate"><span class="pre">OptimizedLSTMCell.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.OptimizedLSTMCell.recurrent_kernel_init"><code class="docutils literal notranslate"><span class="pre">OptimizedLSTMCell.recurrent_kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.OptimizedLSTMCell.bias_init"><code class="docutils literal notranslate"><span class="pre">OptimizedLSTMCell.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.OptimizedLSTMCell.dtype"><code class="docutils literal notranslate"><span class="pre">OptimizedLSTMCell.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.OptimizedLSTMCell.param_dtype"><code class="docutils literal notranslate"><span class="pre">OptimizedLSTMCell.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.OptimizedLSTMCell.__call__"><code class="docutils literal notranslate"><span class="pre">OptimizedLSTMCell.__call__()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.OptimizedLSTMCell.initialize_carry"><code class="docutils literal notranslate"><span class="pre">OptimizedLSTMCell.initialize_carry()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLSTMCell"><code class="docutils literal notranslate"><span class="pre">ConvLSTMCell</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLSTMCell.features"><code class="docutils literal notranslate"><span class="pre">ConvLSTMCell.features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLSTMCell.kernel_size"><code class="docutils literal notranslate"><span class="pre">ConvLSTMCell.kernel_size</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLSTMCell.strides"><code class="docutils literal notranslate"><span class="pre">ConvLSTMCell.strides</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLSTMCell.padding"><code class="docutils literal notranslate"><span class="pre">ConvLSTMCell.padding</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLSTMCell.bias"><code class="docutils literal notranslate"><span class="pre">ConvLSTMCell.bias</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLSTMCell.dtype"><code class="docutils literal notranslate"><span class="pre">ConvLSTMCell.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLSTMCell.param_dtype"><code class="docutils literal notranslate"><span class="pre">ConvLSTMCell.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLSTMCell.__call__"><code class="docutils literal notranslate"><span class="pre">ConvLSTMCell.__call__()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.ConvLSTMCell.initialize_carry"><code class="docutils literal notranslate"><span class="pre">ConvLSTMCell.initialize_carry()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SimpleCell"><code class="docutils literal notranslate"><span class="pre">SimpleCell</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SimpleCell.features"><code class="docutils literal notranslate"><span class="pre">SimpleCell.features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SimpleCell.activation_fn"><code class="docutils literal notranslate"><span class="pre">SimpleCell.activation_fn</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SimpleCell.kernel_init"><code class="docutils literal notranslate"><span class="pre">SimpleCell.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SimpleCell.recurrent_kernel_init"><code class="docutils literal notranslate"><span class="pre">SimpleCell.recurrent_kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SimpleCell.bias_init"><code class="docutils literal notranslate"><span class="pre">SimpleCell.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SimpleCell.dtype"><code class="docutils literal notranslate"><span class="pre">SimpleCell.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SimpleCell.param_dtype"><code class="docutils literal notranslate"><span class="pre">SimpleCell.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SimpleCell.residual"><code class="docutils literal notranslate"><span class="pre">SimpleCell.residual</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SimpleCell.__call__"><code class="docutils literal notranslate"><span class="pre">SimpleCell.__call__()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.SimpleCell.initialize_carry"><code class="docutils literal notranslate"><span class="pre">SimpleCell.initialize_carry()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GRUCell"><code class="docutils literal notranslate"><span class="pre">GRUCell</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GRUCell.features"><code class="docutils literal notranslate"><span class="pre">GRUCell.features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GRUCell.gate_fn"><code class="docutils literal notranslate"><span class="pre">GRUCell.gate_fn</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GRUCell.activation_fn"><code class="docutils literal notranslate"><span class="pre">GRUCell.activation_fn</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GRUCell.kernel_init"><code class="docutils literal notranslate"><span class="pre">GRUCell.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GRUCell.recurrent_kernel_init"><code class="docutils literal notranslate"><span class="pre">GRUCell.recurrent_kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GRUCell.bias_init"><code class="docutils literal notranslate"><span class="pre">GRUCell.bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GRUCell.dtype"><code class="docutils literal notranslate"><span class="pre">GRUCell.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GRUCell.param_dtype"><code class="docutils literal notranslate"><span class="pre">GRUCell.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GRUCell.__call__"><code class="docutils literal notranslate"><span class="pre">GRUCell.__call__()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.GRUCell.initialize_carry"><code class="docutils literal notranslate"><span class="pre">GRUCell.initialize_carry()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell"><code class="docutils literal notranslate"><span class="pre">MGUCell</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.features"><code class="docutils literal notranslate"><span class="pre">MGUCell.features</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.gate_fn"><code class="docutils literal notranslate"><span class="pre">MGUCell.gate_fn</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.activation_fn"><code class="docutils literal notranslate"><span class="pre">MGUCell.activation_fn</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.kernel_init"><code class="docutils literal notranslate"><span class="pre">MGUCell.kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.recurrent_kernel_init"><code class="docutils literal notranslate"><span class="pre">MGUCell.recurrent_kernel_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.forget_bias_init"><code class="docutils literal notranslate"><span class="pre">MGUCell.forget_bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.activation_bias_init"><code class="docutils literal notranslate"><span class="pre">MGUCell.activation_bias_init</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.dtype"><code class="docutils literal notranslate"><span class="pre">MGUCell.dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.param_dtype"><code class="docutils literal notranslate"><span class="pre">MGUCell.param_dtype</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.reset_gate"><code class="docutils literal notranslate"><span class="pre">MGUCell.reset_gate</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.__call__"><code class="docutils literal notranslate"><span class="pre">MGUCell.__call__()</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.MGUCell.initialize_carry"><code class="docutils literal notranslate"><span class="pre">MGUCell.initialize_carry()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN"><code class="docutils literal notranslate"><span class="pre">RNN</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN.cell"><code class="docutils literal notranslate"><span class="pre">RNN.cell</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN.time_major"><code class="docutils literal notranslate"><span class="pre">RNN.time_major</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN.return_carry"><code class="docutils literal notranslate"><span class="pre">RNN.return_carry</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN.reverse"><code class="docutils literal notranslate"><span class="pre">RNN.reverse</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN.keep_order"><code class="docutils literal notranslate"><span class="pre">RNN.keep_order</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN.unroll"><code class="docutils literal notranslate"><span class="pre">RNN.unroll</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN.variable_axes"><code class="docutils literal notranslate"><span class="pre">RNN.variable_axes</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN.variable_broadcast"><code class="docutils literal notranslate"><span class="pre">RNN.variable_broadcast</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN.variable_carry"><code class="docutils literal notranslate"><span class="pre">RNN.variable_carry</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN.split_rngs"><code class="docutils literal notranslate"><span class="pre">RNN.split_rngs</span></code></a></li> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.RNN.__call__"><code class="docutils literal notranslate"><span class="pre">RNN.__call__()</span></code></a></li> </ul> </li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Bidirectional"><code class="docutils literal notranslate"><span class="pre">Bidirectional</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Bidirectional.__call__"><code class="docutils literal notranslate"><span class="pre">Bidirectional.__call__()</span></code></a></li> </ul> </li> </ul> </li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#batchapply">BatchApply</a><ul class="nav section-nav flex-column"> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchApply"><code class="docutils literal notranslate"><span class="pre">BatchApply</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h4 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.BatchApply.__call__"><code class="docutils literal notranslate"><span class="pre">BatchApply.__call__()</span></code></a></li> </ul> </li> </ul> </li> </ul> </nav></div> </div></div> </div> <footer class="bd-footer-content"> <div class="bd-footer-content__inner container"> <div class="footer-item"> <p class="component-author"> By The Flax authors </p> </div> <div class="footer-item"> <p class="copyright"> © Copyright 2023, The Flax authors. <br/> </p> </div> <div class="footer-item"> </div> <div class="footer-item"> </div> </div> </footer> </main> </div> </div> <!-- Scripts loaded after <body> so the DOM is not blocked --> <script src="../../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b"></script> <script src="../../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b"></script> <footer class="bd-footer"> </footer> </body> </html>