<!-- CINXE.COM -->
<!-- Module -->
<!DOCTYPE html> <html lang="en" data-content_root="" > <head> <meta charset="utf-8" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" /> <title>Module</title> <script data-cfasync="false"> document.documentElement.dataset.mode = localStorage.getItem("mode") || ""; document.documentElement.dataset.theme = localStorage.getItem("theme") || ""; </script> <!-- Loaded before other Sphinx assets --> <link href="../../_static/styles/theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link href="../../_static/styles/bootstrap.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link href="../../_static/styles/pydata-sphinx-theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link href="../../_static/vendor/fontawesome/6.5.2/css/all.min.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="../../_static/vendor/fontawesome/6.5.2/webfonts/fa-solid-900.woff2" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="../../_static/vendor/fontawesome/6.5.2/webfonts/fa-brands-400.woff2" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="../../_static/vendor/fontawesome/6.5.2/webfonts/fa-regular-400.woff2" /> <link rel="stylesheet" type="text/css" href="../../_static/pygments.css" /> <link rel="stylesheet" type="text/css" href="../../_static/styles/sphinx-book-theme.css" /> <link rel="stylesheet" type="text/css" href="../../_static/mystnb.4510f1fc1dee50b3e5859aac5469c37c29e427902b24a333a5f9fcb2f0b3ac41.css" /> <link rel="stylesheet" type="text/css" href="../../_static/sphinx-design.5ea377869091fd0449014c60fc090103.min.css" /> <link rel="stylesheet" type="text/css" href="../../_static/css/flax_theme.css" /> <!-- Pre-loaded scripts that we'll load fully later --> <link rel="preload" as="script" href="../../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b" /> <link rel="preload" as="script" 
href="../../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" /> <script src="../../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script> <script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script> <script src="../../_static/jquery.js"></script> <script src="../../_static/underscore.js"></script> <script src="../../_static/_sphinx_javascript_frameworks_compat.js"></script> <script src="../../_static/doctools.js"></script> <script src="../../_static/sphinx_highlight.js"></script> <script src="../../_static/scripts/sphinx-book-theme.js"></script> <script src="../../_static/design-tabs.js"></script> <script>DOCUMENTATION_OPTIONS.pagename = 'api_reference/flax.linen/module';</script> <link rel="shortcut icon" href="../../_static/flax.png"/> <link rel="index" title="Index" href="../../genindex.html" /> <link rel="search" title="Search" href="../../search.html" /> <link rel="next" title="Init/Apply" href="init_apply.html" /> <link rel="prev" title="flax.linen" href="index.html" /> <meta name="viewport" content="width=device-width, initial-scale=1"/> <meta name="docsearch:language" content="en"/> <script async type="text/javascript" src="/_/static/javascript/readthedocs-addons.js"></script><meta name="readthedocs-project-slug" content="flax-linen" /><meta name="readthedocs-version-slug" content="latest" /><meta name="readthedocs-resolver-filename" content="/api_reference/flax.linen/module.html" /><meta name="readthedocs-http-status" content="200" /></head> <body data-bs-spy="scroll" data-bs-target=".bd-toc-nav" data-offset="180" data-bs-root-margin="0px 0px -60%" data-default-mode=""> <div id="pst-skip-link" class="skip-link d-print-none"><a href="#main-content">Skip to main content</a></div> <div id="pst-scroll-pixel-helper"></div> <button type="button" class="btn rounded-pill" id="pst-back-to-top"> <i class="fa-solid fa-arrow-up"></i>Back to top</button> <input 
type="checkbox" class="sidebar-toggle" id="pst-primary-sidebar-checkbox"/> <label class="overlay overlay-primary" for="pst-primary-sidebar-checkbox"></label> <input type="checkbox" class="sidebar-toggle" id="pst-secondary-sidebar-checkbox"/> <label class="overlay overlay-secondary" for="pst-secondary-sidebar-checkbox"></label> <div class="search-button__wrapper"> <div class="search-button__overlay"></div> <div class="search-button__search-container"> <form class="bd-search d-flex align-items-center" action="../../search.html" method="get"> <i class="fa-solid fa-magnifying-glass"></i> <input type="search" class="form-control" name="q" id="search-input" placeholder="Search..." aria-label="Search..." autocomplete="off" autocorrect="off" autocapitalize="off" spellcheck="false"/> <span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span> </form></div> </div> <div class="pst-async-banner-revealer d-none"> <aside id="bd-header-version-warning" class="d-none d-print-none" aria-label="Version warning"></aside> </div> <aside class="bd-header-announcement" aria-label="Announcement"> <div class="bd-header-announcement__content"> <a href="https://flax.readthedocs.io/en/latest/index.html" style="text-decoration: none; color: white;" > This site covers the old Flax Linen API. 
<span style="color: lightgray;">[Explore the new <b>Flax NNX</b> API ✨]</span> </a> </div> </aside> <header class="bd-header navbar navbar-expand-lg bd-navbar d-print-none"> </header> <div class="bd-container"> <div class="bd-container__inner bd-page-width"> <div class="bd-sidebar-primary bd-sidebar"> <div class="sidebar-header-items sidebar-primary__section"> </div> <div class="sidebar-primary-items__start sidebar-primary__section"> <div class="sidebar-primary-item"> <a class="navbar-brand logo" href="../../index.html"> <img src="../../_static/flax.png" class="logo__image only-light" alt="Flax - Home"/> <script>document.write(`<img src="../../_static/flax.png" class="logo__image only-dark" alt="Flax - Home"/>`);</script> </a></div> <div class="sidebar-primary-item"> <script> document.write(` <button class="btn search-button-field search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="fa-solid fa-magnifying-glass"></i> <span class="search-button__default-text">Search</span> <span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span> </button> `); </script></div> <div class="sidebar-primary-item"><nav class="bd-links bd-docs-nav" aria-label="Main"> <div class="bd-toc-item navbar-nav active"> <ul class="current nav bd-sidenav"> <li class="toctree-l1"><a class="reference internal" href="../../quick_start.html">Quick start</a></li> <li class="toctree-l1"><a class="reference internal" href="../../guides/flax_fundamentals/flax_basics.html">Flax Basics</a></li> <li class="toctree-l1 has-children"><a class="reference internal" href="../../guides/index.html">Guides</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l2 has-children"><a class="reference internal" href="../../guides/flax_fundamentals/index.html">Flax 
fundamentals</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference external" href="https://jax.readthedocs.io/en/latest/jax-101/index.html">JAX 101</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/flax_fundamentals/flax_basics.html">Flax Basics</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/flax_fundamentals/state_params.html">Managing Parameters and State</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/flax_fundamentals/setup_or_nncompact.html"><code class="docutils literal notranslate"><span class="pre">setup</span></code> vs <code class="docutils literal notranslate"><span class="pre">compact</span></code></a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/flax_fundamentals/arguments.html">Dealing with Flax Module arguments</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/flax_fundamentals/rng_guide.html">Randomness and PRNGs in Flax</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../../guides/data_preprocessing/index.html">Data preprocessing</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../../guides/data_preprocessing/full_eval.html">Processing the entire Dataset</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/data_preprocessing/loading_datasets.html">Loading datasets</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../../guides/training_techniques/index.html">Training techniques</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid 
fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../../guides/training_techniques/batch_norm.html">Batch normalization</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/training_techniques/dropout.html">Dropout</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/training_techniques/lr_schedule.html">Learning rate scheduling</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/training_techniques/transfer_learning.html">Transfer learning</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/training_techniques/use_checkpointing.html">Save and load checkpoints</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../../guides/parallel_training/index.html">Parallel training</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../../guides/parallel_training/ensembling.html">Ensembling on multiple devices</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/parallel_training/flax_on_pjit.html">Scale up Flax Modules on multiple devices</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../../guides/model_inspection/index.html">Model inspection</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../../guides/model_inspection/model_surgery.html">Model surgery</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/model_inspection/extracting_intermediates.html">Extracting intermediate values</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference 
internal" href="../../guides/converting_and_upgrading/index.html">Converting and upgrading</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../../guides/converting_and_upgrading/haiku_migration_guide.html">Migrating from Haiku to Flax</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/converting_and_upgrading/convert_pytorch_to_flax.html">Convert PyTorch models to Flax</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/converting_and_upgrading/orbax_upgrade_guide.html">Migrate checkpointing to Orbax</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/converting_and_upgrading/optax_update_guide.html">Upgrading my codebase to Optax</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/converting_and_upgrading/linen_upgrade_guide.html">Upgrading my codebase to Linen</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/converting_and_upgrading/rnncell_upgrade_guide.html">RNNCellBase Upgrade Guide</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/converting_and_upgrading/regular_dict_upgrade_guide.html">Migrate to regular dicts</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../../guides/quantization/index.html">Quantization</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../../guides/quantization/fp8_basics.html">User Guide on Using FP8</a></li> </ul> </details></li> <li class="toctree-l2"><a class="reference internal" href="../../guides/flax_sharp_bits.html">The Sharp Bits</a></li> </ul> </details></li> <li class="toctree-l1 has-children"><a class="reference 
internal" href="../../examples/index.html">Examples</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l2"><a class="reference internal" href="../../examples/core_examples.html">Core examples</a></li> <li class="toctree-l2"><a class="reference internal" href="../../examples/google_research_examples.html">Google Research examples</a></li> <li class="toctree-l2"><a class="reference internal" href="../../examples/repositories_that_use_flax.html">Repositories that use Flax</a></li> <li class="toctree-l2"><a class="reference internal" href="../../examples/community_examples.html">Community examples</a></li> </ul> </details></li> <li class="toctree-l1"><a class="reference internal" href="../../glossary.html">Glossary</a></li> <li class="toctree-l1"><a class="reference internal" href="../../faq.html">Frequently Asked Questions (FAQ)</a></li> <li class="toctree-l1 has-children"><a class="reference internal" href="../../developer_notes/index.html">Developer notes</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l2"><a class="reference internal" href="../../developer_notes/module_lifecycle.html">The Flax Module lifecycle</a></li> <li class="toctree-l2"><a class="reference internal" href="../../developer_notes/lift.html">Lifted transformations</a></li> <li class="toctree-l2"><a class="reference external" href="https://github.com/google/flax/tree/main/docs/flip">FLIPs</a></li> </ul> </details></li> <li class="toctree-l1"><a class="reference internal" href="../../philosophy.html">The Flax philosophy</a></li> <li class="toctree-l1"><a class="reference internal" href="../../contributing.html">How to contribute</a></li> <li class="toctree-l1 current active has-children"><a class="reference internal" href="../index.html">API Reference</a><details open="open"><summary><span 
class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul class="current"> <li class="toctree-l2"><a class="reference internal" href="../flax.config.html">flax.config package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.core.frozen_dict.html">flax.core.frozen_dict package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.cursor.html">flax.cursor package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.errors.html">flax.errors package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.jax_utils.html">flax.jax_utils package</a></li> <li class="toctree-l2 current active has-children"><a class="reference internal" href="index.html">flax.linen</a><details open="open"><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul class="current"> <li class="toctree-l3 current active"><a class="current reference internal" href="#">Module</a></li> <li class="toctree-l3"><a class="reference internal" href="init_apply.html">Init/Apply</a></li> <li class="toctree-l3"><a class="reference internal" href="layers.html">Layers</a></li> <li class="toctree-l3"><a class="reference internal" href="activation_functions.html">Activation functions</a></li> <li class="toctree-l3"><a class="reference internal" href="initializers.html">Initializers</a></li> <li class="toctree-l3"><a class="reference internal" href="transformations.html">Transformations</a></li> <li class="toctree-l3"><a class="reference internal" href="inspection.html">Inspection</a></li> <li class="toctree-l3"><a class="reference internal" href="variable.html">Variable dictionary</a></li> <li class="toctree-l3"><a class="reference internal" href="spmd.html">SPMD</a></li> <li class="toctree-l3"><a class="reference internal" href="decorators.html">Decorators</a></li> <li class="toctree-l3"><a class="reference 
internal" href="profiling.html">Profiling</a></li> </ul> </details></li> <li class="toctree-l2"><a class="reference internal" href="../flax.serialization.html">flax.serialization package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.struct.html">flax.struct package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.traceback_util.html">flax.traceback_util package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.training.html">flax.training package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.traverse_util.html">flax.traverse_util package</a></li> </ul> </details></li> <li class="toctree-l1"><a class="reference external" href="https://flax.readthedocs.io/en/latest/index.html">Flax NNX</a></li> </ul> </div> </nav></div> </div> <div class="sidebar-primary-items__end sidebar-primary__section"> </div> <div id="rtd-footer-container"></div> </div> <main id="main-content" class="bd-main" role="main"> <div class="sbt-scroll-pixel-helper"></div> <div class="bd-content"> <div class="bd-article-container"> <div class="bd-header-article d-print-none"> <div class="header-article-items header-article__inner"> <div class="header-article-items__start"> <div class="header-article-item"><button class="sidebar-toggle primary-toggle btn btn-sm" title="Toggle primary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip"> <span class="fa-solid fa-bars"></span> </button></div> </div> <div class="header-article-items__end"> <div class="header-article-item"> <div class="article-header-buttons"> <a href="https://github.com/google/flax" target="_blank" class="btn btn-sm btn-source-repository-button" title="Source repository" data-bs-placement="bottom" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fab fa-github"></i> </span> </a> <div class="dropdown dropdown-download-buttons"> <button class="btn dropdown-toggle" type="button" 
data-bs-toggle="dropdown" aria-expanded="false" aria-label="Download this page"> <i class="fas fa-download"></i> </button> <ul class="dropdown-menu"> <li><a href="../../_sources/api_reference/flax.linen/module.rst" target="_blank" class="btn btn-sm btn-download-source-button dropdown-item" title="Download source file" data-bs-placement="left" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-file"></i> </span> <span class="btn__text-container">.rst</span> </a> </li> <li> <button onclick="window.print()" class="btn btn-sm btn-download-pdf-button dropdown-item" title="Print to PDF" data-bs-placement="left" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-file-pdf"></i> </span> <span class="btn__text-container">.pdf</span> </button> </li> </ul> </div> <button onclick="toggleFullScreen()" class="btn btn-sm btn-fullscreen-button" title="Fullscreen mode" data-bs-placement="bottom" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-expand"></i> </span> </button> <script> document.write(` <button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button" title="light/dark" aria-label="light/dark" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light"></i> <i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark"></i> <i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto"></i> </button> `); </script> <script> document.write(` <button class="btn btn-sm pst-navbar-icon search-button search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="fa-solid fa-magnifying-glass fa-lg"></i> </button> `); </script> <button class="sidebar-toggle secondary-toggle btn btn-sm" title="Toggle secondary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip"> <span class="fa-solid fa-list"></span> </button> </div></div> </div> </div> </div> <div 
id="jb-print-docs-body" class="onlyprint"> <h1>Module</h1> <!-- Table of contents --> <div id="print-main-content"> <div id="jb-print-toc"> <div> <h2> Contents </h2> </div> <nav aria-label="Page"> <ul class="visible nav section-nav flex-column"> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module"><code class="docutils literal notranslate"><span class="pre">Module</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.__setattr__"><code class="docutils literal notranslate"><span class="pre">Module.__setattr__()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.apply"><code class="docutils literal notranslate"><span class="pre">Module.apply()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.bind"><code class="docutils literal notranslate"><span class="pre">Module.bind()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.copy"><code class="docutils literal notranslate"><span class="pre">Module.copy()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.get_variable"><code class="docutils literal notranslate"><span class="pre">Module.get_variable()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.has_rng"><code class="docutils literal notranslate"><span class="pre">Module.has_rng()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.has_variable"><code class="docutils literal notranslate"><span class="pre">Module.has_variable()</span></code></a></li> <li class="toc-h3 nav-item 
toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.init"><code class="docutils literal notranslate"><span class="pre">Module.init()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.init_with_output"><code class="docutils literal notranslate"><span class="pre">Module.init_with_output()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.is_initializing"><code class="docutils literal notranslate"><span class="pre">Module.is_initializing()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.is_mutable_collection"><code class="docutils literal notranslate"><span class="pre">Module.is_mutable_collection()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.lazy_init"><code class="docutils literal notranslate"><span class="pre">Module.lazy_init()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.make_rng"><code class="docutils literal notranslate"><span class="pre">Module.make_rng()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.module_paths"><code class="docutils literal notranslate"><span class="pre">Module.module_paths()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.param"><code class="docutils literal notranslate"><span class="pre">Module.param()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.path"><code class="docutils literal notranslate"><span class="pre">Module.path</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference 
internal nav-link" href="#flax.linen.Module.perturb"><code class="docutils literal notranslate"><span class="pre">Module.perturb()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.put_variable"><code class="docutils literal notranslate"><span class="pre">Module.put_variable()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.setup"><code class="docutils literal notranslate"><span class="pre">Module.setup()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.sow"><code class="docutils literal notranslate"><span class="pre">Module.sow()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.tabulate"><code class="docutils literal notranslate"><span class="pre">Module.tabulate()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.unbind"><code class="docutils literal notranslate"><span class="pre">Module.unbind()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.variable"><code class="docutils literal notranslate"><span class="pre">Module.variable()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.variables"><code class="docutils literal notranslate"><span class="pre">Module.variables</span></code></a></li> </ul> </li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.apply"><code class="docutils literal notranslate"><span class="pre">apply()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.init"><code class="docutils literal notranslate"><span 
class="pre">init()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.init_with_output"><code class="docutils literal notranslate"><span class="pre">init_with_output()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.intercept_methods"><code class="docutils literal notranslate"><span class="pre">intercept_methods()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.share_scope"><code class="docutils literal notranslate"><span class="pre">share_scope()</span></code></a></li> </ul> </nav> </div> </div> </div> <div id="searchbox"></div> <article class="bd-article"> <div class="section" id="module-flax.linen"> <span id="module"></span><h1>Module<a class="headerlink" href="#module-flax.linen" title="Permalink to this heading">#</a></h1> <p>The Flax Module system.</p> <dl class="py class"> <dt class="sig sig-object py" id="flax.linen.Module"> <em class="property"><span class="pre">class</span><span class="w"> </span></em><span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">Module</span></span><a class="reference internal" href="../../_modules/flax/linen/module.html#Module"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module" title="Permalink to this definition">#</a></dt> <dd><p>Base class for all neural network modules.</p> <p>Layers and models should subclass this class.</p> <p>All Flax Modules are Python 3.7 <a class="reference external" href="https://docs.python.org/3/library/dataclasses.html">dataclasses</a>. 
Since dataclasses take over <code class="docutils literal notranslate"><span class="pre">__init__</span></code>, you should instead override <a class="reference internal" href="#flax.linen.Module.setup" title="flax.linen.Module.setup"><code class="xref py py-meth docutils literal notranslate"><span class="pre">setup()</span></code></a>, which is automatically called to initialize the module.</p> <p>Modules can contain submodules, and in this way can be nested in a tree structure. Submodules can be assigned as regular attributes inside the <a class="reference internal" href="#flax.linen.Module.setup" title="flax.linen.Module.setup"><code class="xref py py-meth docutils literal notranslate"><span class="pre">setup()</span></code></a> method.</p> <p>You can define arbitrary “forward pass” methods on your Module subclass. While no methods are special-cased, <code class="docutils literal notranslate"><span class="pre">__call__</span></code> is a popular choice because it allows you to use module instances as if they are functions:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">from</span> <span class="nn">flax</span> <span class="kn">import</span> <span class="n">linen</span> <span class="k">as</span> <span class="n">nn</span> <span class="gp">>>> </span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">Tuple</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">Module</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... 
</span> <span class="n">features</span><span class="p">:</span> <span class="n">Tuple</span><span class="p">[</span><span class="nb">int</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span> <span class="go">... def setup(self):</span> <span class="go">... self.dense1 = nn.Dense(self.features[0])</span> <span class="go">... self.dense2 = nn.Dense(self.features[1])</span> <span class="go">... def __call__(self, x):</span> <span class="go">... return self.dense2(nn.relu(self.dense1(x)))</span> </pre></div> </div> <p>Optionally, for more concise module implementations where submodule definitions are co-located with their usage, you can use the <a class="reference internal" href="decorators.html#flax.linen.compact" title="flax.linen.compact"><code class="xref py py-meth docutils literal notranslate"><span class="pre">compact()</span></code></a> wrapper.</p> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Module.__setattr__"> <span class="sig-name descname"><span class="pre">__setattr__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">val</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#Module.__setattr__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module.__setattr__" title="Permalink to this definition">#</a></dt> <dd><p>Sets an attribute on this Module.</p> <p>We overload setattr solely to support pythonic naming via assignment of submodules in the special <a class="reference internal" href="#flax.linen.Module.setup" title="flax.linen.Module.setup"><code class="xref py py-meth docutils literal 
notranslate"><span class="pre">setup()</span></code></a> function:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="bp">self</span><span class="o">.</span><span class="n">submodule_name</span> <span class="o">=</span> <span class="n">MyModule</span><span class="p">(</span><span class="o">...</span><span class="p">)</span> </pre></div> </div> <p>We also support lists and other general pytrees, e.g.:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="bp">self</span><span class="o">.</span><span class="n">submodules</span> <span class="o">=</span> <span class="p">[</span><span class="n">MyModule0</span><span class="p">(</span><span class="o">..</span><span class="p">),</span> <span class="n">MyModule1</span><span class="p">(</span><span class="o">..</span><span class="p">),</span> <span class="o">...</span><span class="p">]</span> </pre></div> </div> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>name</strong> – Attribute to set.</p></li> <li><p><strong>val</strong> – Value of the attribute.</p></li> </ul> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Module.apply"> <span class="sig-name descname"><span class="pre">apply</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">variables</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">rngs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">method</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span 
class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mutable</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">capture_intermediates</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#Module.apply"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module.apply" title="Permalink to this definition">#</a></dt> <dd><p>Applies a module method to variables and returns output and modified variables.</p> <p>Note that <code class="docutils literal notranslate"><span class="pre">method</span></code> should be set if one would like to call <code class="docutils literal notranslate"><span class="pre">apply</span></code> on a different class method than <code class="docutils literal notranslate"><span class="pre">__call__</span></code>. 
For instance, suppose a Transformer module has a method called <code class="docutils literal notranslate"><span class="pre">encode</span></code>, then the following calls <code class="docutils literal notranslate"><span class="pre">apply</span></code> on that method:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">Transformer</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="k">def</span> <span class="nf">encode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... 
</span> <span class="o">...</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">16</span><span class="p">,</span> <span class="mi">9</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">model</span> <span class="o">=</span> <span class="n">Transformer</span><span class="p">()</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="p">,</span> <span class="n">method</span><span class="o">=</span><span class="n">Transformer</span><span class="o">.</span><span class="n">encode</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">encoded</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">method</span><span class="o">=</span><span class="n">Transformer</span><span class="o">.</span><span class="n">encode</span><span class="p">)</span> </pre></div> </div> <p>If a bound method of a module instance is provided, its underlying unbound function is used. 
For instance, the example below is equivalent to the one above:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">encoded</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">method</span><span class="o">=</span><span class="n">model</span><span class="o">.</span><span class="n">encode</span><span class="p">)</span> </pre></div> </div> <p>You can also pass the name of a callable attribute of the module as a string. For example, the previous call can be written as:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">encoded</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">method</span><span class="o">=</span><span class="s1">'encode'</span><span class="p">)</span> </pre></div> </div> <p>Note that <code class="docutils literal notranslate"><span class="pre">method</span></code> can also be a function that is not defined in <code class="docutils literal notranslate"><span class="pre">Transformer</span></code>. In that case, the function should have at least one argument representing an instance of the Module class:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="k">def</span> <span class="nf">other_fn</span><span class="p">(</span><span class="n">instance</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="c1"># instance.some_module_attr(...)</span> <span class="gp">... 
</span> <span class="n">instance</span><span class="o">.</span><span class="n">encode</span> <span class="gp">... </span> <span class="o">...</span> <span class="gp">>>> </span><span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">method</span><span class="o">=</span><span class="n">other_fn</span><span class="p">)</span> </pre></div> </div> <p>If you pass a single <code class="docutils literal notranslate"><span class="pre">PRNGKey</span></code>, Flax will use it to feed the <code class="docutils literal notranslate"><span class="pre">'params'</span></code> RNG stream. If you want to use a different RNG stream or need to use multiple streams, you can pass a dictionary mapping each RNG stream name to its corresponding <code class="docutils literal notranslate"><span class="pre">PRNGKey</span></code> to <code class="docutils literal notranslate"><span class="pre">apply</span></code>. If <code class="docutils literal notranslate"><span class="pre">self.make_rng(name)</span></code> is called on an RNG stream name that isn’t passed by the user, it will default to using the <code class="docutils literal notranslate"><span class="pre">'params'</span></code> RNG stream.</p> <p>Example:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="k">class</span> <span class="nc">Foo</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... 
</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">add_noise</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">16</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="gp">...</span> <span class="gp">... </span> <span class="k">if</span> <span class="n">add_noise</span><span class="p">:</span> <span class="gp">... </span> <span class="c1"># Add gaussian noise</span> <span class="gp">... </span> <span class="n">noise_key</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">make_rng</span><span class="p">(</span><span class="s1">'noise'</span><span class="p">)</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">noise_key</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="gp">...</span> <span class="gp">... 
</span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">7</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">module</span> <span class="o">=</span> <span class="n">Foo</span><span class="p">()</span> <span class="gp">>>> </span><span class="n">rngs</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'params'</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="s1">'noise'</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">1</span><span class="p">)}</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">rngs</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">out0</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">add_noise</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span 
class="n">rngs</span><span class="o">=</span><span class="n">rngs</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">rngs</span><span class="p">[</span><span class="s1">'noise'</span><span class="p">]</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">out1</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">add_noise</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">rngs</span><span class="o">=</span><span class="n">rngs</span><span class="p">)</span> <span class="gp">>>> </span><span class="c1"># different output (key(1) vs key(0))</span> <span class="gp">>>> </span><span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_raises</span><span class="p">(</span><span class="ne">AssertionError</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">,</span> <span class="n">out0</span><span class="p">,</span> <span class="n">out1</span><span class="p">)</span> <span class="gp">>>> </span><span class="k">del</span> <span class="n">rngs</span><span class="p">[</span><span class="s1">'noise'</span><span class="p">]</span> <span class="gp">>>> </span><span class="c1"># self.make_rng('noise') will default to using the 'params' RNG stream</span> <span class="gp">>>> </span><span class="n">out2</span> <span class="o">=</span> <span class="n">module</span><span 
class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">add_noise</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">rngs</span><span class="o">=</span><span class="n">rngs</span><span class="p">)</span> <span class="gp">>>> </span><span class="c1"># same output (key(0))</span> <span class="gp">>>> </span><span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span><span class="n">out1</span><span class="p">,</span> <span class="n">out2</span><span class="p">)</span> <span class="gp">>>> </span><span class="c1"># passing in a single key is equivalent to passing in {'params': key}</span> <span class="gp">>>> </span><span class="n">out3</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">add_noise</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">rngs</span><span class="o">=</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">))</span> <span class="gp">>>> </span><span class="c1"># same output (key(0))</span> <span class="gp">>>> </span><span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span><span class="n">out2</span><span class="p">,</span> <span class="n">out3</span><span class="p">)</span> </pre></div> </div> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> 
<dd class="field-odd"><ul class="simple"> <li><p><strong>variables</strong> – A dictionary containing variables keyed by variable collections. See <a class="reference internal" href="variable.html#module-flax.core.variables" title="flax.core.variables"><code class="xref py py-mod docutils literal notranslate"><span class="pre">flax.core.variables</span></code></a> for more details about variables.</p></li> <li><p><strong>*args</strong> – Positional arguments passed to the specified apply method.</p></li> <li><p><strong>rngs</strong> – a dict of PRNGKeys to initialize the PRNG sequences. The “params” PRNG sequence is used to initialize parameters.</p></li> <li><p><strong>method</strong> – A function to call apply on. This is generally a function in the module. If provided, applies this method. If not provided, applies the <code class="docutils literal notranslate"><span class="pre">__call__</span></code> method of the module. A string can also be provided to specify a method by name.</p></li> <li><p><strong>mutable</strong> – Can be bool, str, or list. Specifies which collections should be treated as mutable: <code class="docutils literal notranslate"><span class="pre">bool</span></code>: all/no collections are mutable. <code class="docutils literal notranslate"><span class="pre">str</span></code>: The name of a single mutable collection. <code class="docutils literal notranslate"><span class="pre">list</span></code>: A list of names of mutable collections.</p></li> <li><p><strong>capture_intermediates</strong> – If <code class="docutils literal notranslate"><span class="pre">True</span></code>, captures intermediate return values of all Modules inside the “intermediates” collection. By default, only the return values of all <code class="docutils literal notranslate"><span class="pre">__call__</span></code> methods are stored. A function can be passed to change the filter behavior. 
The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.</p></li> <li><p><strong>**kwargs</strong> – Keyword arguments passed to the specified apply method.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>If <code class="docutils literal notranslate"><span class="pre">mutable</span></code> is False, returns output. If any collections are mutable, returns <code class="docutils literal notranslate"><span class="pre">(output,</span> <span class="pre">vars)</span></code>, where <code class="docutils literal notranslate"><span class="pre">vars</span></code> is a dict of the modified collections.</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Module.bind"> <span class="sig-name descname"><span class="pre">bind</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">variables</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">rngs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mutable</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#Module.bind"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module.bind" title="Permalink to this definition">#</a></dt> <dd><p>Creates an interactive Module instance by binding variables and RNGs.</p> <p><code class="docutils literal notranslate"><span 
class="pre">bind</span></code> provides an “interactive” instance of a Module directly without transforming a function with <code class="docutils literal notranslate"><span class="pre">apply</span></code>. This is particularly useful for debugging and interactive use cases like notebooks where a function would limit the ability to split up code into different cells.</p> <p>Once the variables (and optionally RNGs) are bound to a <code class="docutils literal notranslate"><span class="pre">Module</span></code> it becomes a stateful object. Note that idiomatic JAX is functional and therefore an interactive instance does not mix well with vanilla JAX APIs. <code class="docutils literal notranslate"><span class="pre">bind()</span></code> should only be used for interactive experimentation, and in all other cases we strongly encourage users to use <code class="docutils literal notranslate"><span class="pre">apply()</span></code> instead.</p> <p>Example:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">AutoEncoder</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="k">def</span> <span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="gp">... 
</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span> <span class="gp">... </span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">5</span><span class="p">)</span> <span class="gp">...</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">x</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">16</span><span class="p">,</span> <span class="mi">9</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">ae</span> <span class="o">=</span> <span class="n">AutoEncoder</span><span class="p">()</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">ae</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> 
</span><span class="n">model</span> <span class="o">=</span> <span class="n">ae</span><span class="o">.</span><span class="n">bind</span><span class="p">(</span><span class="n">variables</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">z</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">x_reconstructed</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="n">z</span><span class="p">)</span> </pre></div> </div> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>variables</strong> – A dictionary containing variables keyed by variable collections. See <a class="reference internal" href="variable.html#module-flax.core.variables" title="flax.core.variables"><code class="xref py py-mod docutils literal notranslate"><span class="pre">flax.core.variables</span></code></a> for more details about variables.</p></li> <li><p><strong>*args</strong> – Positional arguments (not used).</p></li> <li><p><strong>rngs</strong> – a dict of PRNGKeys to initialize the PRNG sequences.</p></li> <li><p><strong>mutable</strong> – Can be bool, str, or list. Specifies which collections should be treated as mutable: <code class="docutils literal notranslate"><span class="pre">bool</span></code>: all/no collections are mutable. <code class="docutils literal notranslate"><span class="pre">str</span></code>: The name of a single mutable collection. 
<code class="docutils literal notranslate"><span class="pre">list</span></code>: A list of names of mutable collections.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>A copy of this instance with bound variables and RNGs.</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Module.copy"> <span class="sig-name descname"><span class="pre">copy</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">*</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">parent=&lt;flax.linen.module._Sentinel</span> <span class="pre">object&gt;</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name=None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">**updates</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#Module.copy"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module.copy" title="Permalink to this definition">#</a></dt> <dd><p>Creates a copy of this Module, with optionally updated arguments.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>parent</strong> – The parent of the copy. 
By default the current module is taken as parent if not explicitly specified.</p></li> <li><p><strong>name</strong> – A new name for the copied Module; by default, a new automatic name will be given.</p></li> <li><p><strong>**updates</strong> – Attribute updates.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>A copy of this Module with the updated name, parent, and attributes.</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Module.get_variable"> <span class="sig-name descname"><span class="pre">get_variable</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">col</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">default</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#Module.get_variable"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module.get_variable" title="Permalink to this definition">#</a></dt> <dd><p>Retrieves the value of a Variable.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>col</strong> – the variable collection.</p></li> <li><p><strong>name</strong> – the name of the variable.</p></li> <li><p><strong>default</strong> – the default value to return if the variable does not exist in this scope.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>The value of the input variable, or the default value if the variable doesn’t exist in this scope.</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Module.has_rng"> 
<span class="sig-name descname"><span class="pre">has_rng</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">name</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#Module.has_rng"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module.has_rng" title="Permalink to this definition">#</a></dt> <dd><p>Returns true if a PRNGSequence with name <code class="docutils literal notranslate"><span class="pre">name</span></code> exists.</p> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Module.has_variable"> <span class="sig-name descname"><span class="pre">has_variable</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">col</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#Module.has_variable"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module.has_variable" title="Permalink to this definition">#</a></dt> <dd><p>Checks if a variable of given collection and name exists in this Module.</p> <p>See <a class="reference internal" href="variable.html#module-flax.core.variables" title="flax.core.variables"><code class="xref py py-mod docutils literal notranslate"><span class="pre">flax.core.variables</span></code></a> for more explanation on variables and collections.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>col</strong> – The variable collection name.</p></li> <li><p><strong>name</strong> – The name of the variable.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>True if the 
variable exists.</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Module.init"> <span class="sig-name descname"><span class="pre">init</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">rngs</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">method</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mutable</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">DenyList(deny='intermediates')</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">capture_intermediates</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#Module.init"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module.init" title="Permalink to this definition">#</a></dt> <dd><p>Initializes a module method with variables and returns modified variables.</p> <p><code class="docutils literal notranslate"><span class="pre">init</span></code> takes as first argument either a single <code class="docutils literal notranslate"><span class="pre">PRNGKey</span></code>, or a dictionary mapping RNG stream names to their <code class="docutils literal notranslate"><span class="pre">PRNGKeys</span></code>, and will call <code class="docutils literal notranslate"><span 
class="pre">method</span></code> (which is the module’s <code class="docutils literal notranslate"><span class="pre">__call__</span></code> function by default) passing <code class="docutils literal notranslate"><span class="pre">*args</span></code> and <code class="docutils literal notranslate"><span class="pre">**kwargs</span></code>, and returns a dictionary of initialized variables.</p> <p>Example:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">Foo</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">train</span><span class="p">):</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">16</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... 
</span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">(</span><span class="n">use_running_average</span><span class="o">=</span><span class="ow">not</span> <span class="n">train</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">7</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">module</span> <span class="o">=</span> <span class="n">Foo</span><span class="p">()</span> <span class="gp">>>> </span><span class="n">key</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> </pre></div> </div> <p>If you pass a single <code 
class="docutils literal notranslate"><span class="pre">PRNGKey</span></code>, Flax will use it to feed the <code class="docutils literal notranslate"><span class="pre">'params'</span></code> RNG stream. If you want to use a different RNG stream or need to use multiple streams, you can pass a dictionary mapping each RNG stream name to its corresponding <code class="docutils literal notranslate"><span class="pre">PRNGKey</span></code> to <code class="docutils literal notranslate"><span class="pre">init</span></code>. If <code class="docutils literal notranslate"><span class="pre">self.make_rng(name)</span></code> is called on an RNG stream name that isn’t passed by the user, it will default to using the <code class="docutils literal notranslate"><span class="pre">'params'</span></code> RNG stream.</p> <p>Example:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="k">class</span> <span class="nc">Foo</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">16</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... 
</span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="gp">...</span> <span class="gp">... </span> <span class="n">other_variable</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">variable</span><span class="p">(</span> <span class="gp">... </span> <span class="s1">'other_collection'</span><span class="p">,</span> <span class="gp">... </span> <span class="s1">'other_variable'</span><span class="p">,</span> <span class="gp">... </span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">make_rng</span><span class="p">(</span><span class="s1">'other_rng'</span><span class="p">),</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">),</span> <span class="gp">... </span> <span class="n">x</span><span class="p">,</span> <span class="gp">... </span> <span class="p">)</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">other_variable</span><span class="o">.</span><span class="n">value</span> <span class="gp">...</span> <span class="gp">... 
</span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">module</span> <span class="o">=</span> <span class="n">Foo</span><span class="p">()</span> <span class="gp">>>> </span><span class="n">rngs</span> <span class="o">=</span> <span class="p">{</span><span class="s1">'params'</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="s1">'other_rng'</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">1</span><span class="p">)}</span> <span class="gp">>>> </span><span class="n">variables0</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">rngs</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">rngs</span><span class="p">[</span><span class="s1">'other_rng'</span><span class="p">]</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables1</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">rngs</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span 
class="c1"># equivalent params (key(0))</span> <span class="gp">>>> </span><span class="n">_</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span> <span class="gp">... </span> <span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">,</span> <span class="n">variables0</span><span class="p">[</span><span class="s1">'params'</span><span class="p">],</span> <span class="n">variables1</span><span class="p">[</span><span class="s1">'params'</span><span class="p">]</span> <span class="gp">... </span><span class="p">)</span> <span class="gp">>>> </span><span class="c1"># different other_variable (key(1) vs key(0))</span> <span class="gp">>>> </span><span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_raises</span><span class="p">(</span> <span class="gp">... </span> <span class="ne">AssertionError</span><span class="p">,</span> <span class="gp">... </span> <span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">,</span> <span class="gp">... </span> <span class="n">variables0</span><span class="p">[</span><span class="s1">'other_collection'</span><span class="p">][</span><span class="s1">'other_variable'</span><span class="p">],</span> <span class="gp">... </span> <span class="n">variables1</span><span class="p">[</span><span class="s1">'other_collection'</span><span class="p">][</span><span class="s1">'other_variable'</span><span class="p">],</span> <span class="gp">... 
</span><span class="p">)</span> <span class="gp">>>> </span><span class="k">del</span> <span class="n">rngs</span><span class="p">[</span><span class="s1">'other_rng'</span><span class="p">]</span> <span class="gp">>>> </span><span class="c1"># self.make_rng('other_rng') will default to using the 'params' RNG stream</span> <span class="gp">>>> </span><span class="n">variables2</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">rngs</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="c1"># equivalent params (key(0))</span> <span class="gp">>>> </span><span class="n">_</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span> <span class="gp">... </span> <span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">,</span> <span class="n">variables1</span><span class="p">[</span><span class="s1">'params'</span><span class="p">],</span> <span class="n">variables2</span><span class="p">[</span><span class="s1">'params'</span><span class="p">]</span> <span class="gp">... </span><span class="p">)</span> <span class="gp">>>> </span><span class="c1"># equivalent other_variable (key(0))</span> <span class="gp">>>> </span><span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span> <span class="gp">... </span> <span class="n">variables1</span><span class="p">[</span><span class="s1">'other_collection'</span><span class="p">][</span><span class="s1">'other_variable'</span><span class="p">],</span> <span class="gp">... 
</span> <span class="n">variables2</span><span class="p">[</span><span class="s1">'other_collection'</span><span class="p">][</span><span class="s1">'other_variable'</span><span class="p">],</span> <span class="gp">... </span><span class="p">)</span> <span class="gp">>>> </span><span class="c1"># passing in a single key is equivalent to passing in {'params': key}</span> <span class="gp">>>> </span><span class="n">variables3</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="c1"># equivalent params (key(0))</span> <span class="gp">>>> </span><span class="n">_</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span> <span class="gp">... </span> <span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">,</span> <span class="n">variables2</span><span class="p">[</span><span class="s1">'params'</span><span class="p">],</span> <span class="n">variables3</span><span class="p">[</span><span class="s1">'params'</span><span class="p">]</span> <span class="gp">... </span><span class="p">)</span> <span class="gp">>>> </span><span class="c1"># equivalent other_variable (key(0))</span> <span class="gp">>>> </span><span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_allclose</span><span class="p">(</span> <span class="gp">... 
</span> <span class="n">variables2</span><span class="p">[</span><span class="s1">'other_collection'</span><span class="p">][</span><span class="s1">'other_variable'</span><span class="p">],</span> <span class="gp">... </span> <span class="n">variables3</span><span class="p">[</span><span class="s1">'other_collection'</span><span class="p">][</span><span class="s1">'other_variable'</span><span class="p">],</span> <span class="gp">... </span><span class="p">)</span> </pre></div> </div> <p>Jitting <code class="docutils literal notranslate"><span class="pre">init</span></code> initializes a model lazily using only the shapes of the provided arguments, and avoids computing the forward pass with actual values. Example:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">module</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">init_jit</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">jit</span><span class="p">(</span><span class="n">module</span><span class="o">.</span><span class="n">init</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">init_jit</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="p">)</span> </pre></div> </div> <p><code class="docutils literal notranslate"><span class="pre">init</span></code> is a light wrapper over <code class="docutils literal notranslate"><span class="pre">apply</span></code>, so other <code class="docutils literal notranslate"><span 
class="pre">apply</span></code> arguments like <code class="docutils literal notranslate"><span class="pre">method</span></code>, <code class="docutils literal notranslate"><span class="pre">mutable</span></code>, and <code class="docutils literal notranslate"><span class="pre">capture_intermediates</span></code> are also available.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>rngs</strong> – The rngs for the variable collections.</p></li> <li><p><strong>*args</strong> – Positional arguments passed to the init function.</p></li> <li><p><strong>method</strong> – An optional method. If provided, applies this method. If not provided, applies the <code class="docutils literal notranslate"><span class="pre">__call__</span></code> method. A string can also be provided to specify a method by name.</p></li> <li><p><strong>mutable</strong> – Can be bool, str, or list. Specifies which collections should be treated as mutable: <code class="docutils literal notranslate"><span class="pre">bool</span></code>: all/no collections are mutable. <code class="docutils literal notranslate"><span class="pre">str</span></code>: The name of a single mutable collection. <code class="docutils literal notranslate"><span class="pre">list</span></code>: A list of names of mutable collections. By default all collections except “intermediates” are mutable.</p></li> <li><p><strong>capture_intermediates</strong> – If <code class="docutils literal notranslate"><span class="pre">True</span></code>, captures intermediate return values of all Modules inside the “intermediates” collection. By default only the return values of all <code class="docutils literal notranslate"><span class="pre">__call__</span></code> methods are stored. A function can be passed to change the filter behavior. 
The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.</p></li> <li><p><strong>**kwargs</strong> – Keyword arguments passed to the init function.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>The initialized variable dict.</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Module.init_with_output"> <span class="sig-name descname"><span class="pre">init_with_output</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">rngs</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">method</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mutable</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">DenyList(deny='intermediates')</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">capture_intermediates</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#Module.init_with_output"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module.init_with_output" title="Permalink to this definition">#</a></dt> <dd><p>Initializes a module method with variables and returns output and modified variables.</p> <dl 
class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>rngs</strong> – The rngs for the variable collections.</p></li> <li><p><strong>*args</strong> – Positional arguments passed to the init function.</p></li> <li><p><strong>method</strong> – An optional method. If provided, applies this method. If not provided, applies the <code class="docutils literal notranslate"><span class="pre">__call__</span></code> method. A string can also be provided to specify a method by name.</p></li> <li><p><strong>mutable</strong> – Can be bool, str, or list. Specifies which collections should be treated as mutable: <code class="docutils literal notranslate"><span class="pre">bool</span></code>: all/no collections are mutable. <code class="docutils literal notranslate"><span class="pre">str</span></code>: The name of a single mutable collection. <code class="docutils literal notranslate"><span class="pre">list</span></code>: A list of names of mutable collections. By default, all collections except “intermediates” are mutable.</p></li> <li><p><strong>capture_intermediates</strong> – If <code class="docutils literal notranslate"><span class="pre">True</span></code>, captures intermediate return values of all Modules inside the “intermediates” collection. By default only the return values of all <code class="docutils literal notranslate"><span class="pre">__call__</span></code> methods are stored. A function can be passed to change the filter behavior. 
The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.</p></li> <li><p><strong>**kwargs</strong> – Keyword arguments passed to the init function.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p><code class="docutils literal notranslate"><span class="pre">(output,</span> <span class="pre">vars)</span></code>, where <code class="docutils literal notranslate"><span class="pre">vars</span></code> is a dict of the modified collections.</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Module.is_initializing"> <span class="sig-name descname"><span class="pre">is_initializing</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#Module.is_initializing"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module.is_initializing" title="Permalink to this definition">#</a></dt> <dd><p>Returns True if running under self.init(…) or nn.init(…)().</p> <p>This is a helper method to handle the common case of simple initialization where we wish to have setup logic occur only when called under <code class="docutils literal notranslate"><span class="pre">module.init</span></code> or <code class="docutils literal notranslate"><span class="pre">nn.init</span></code>. 
For more complicated multi-phase initialization scenarios it is better to test for the mutability of particular variable collections or for the presence of particular variables that potentially need to be initialized.</p> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Module.is_mutable_collection"> <span class="sig-name descname"><span class="pre">is_mutable_collection</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">col</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#Module.is_mutable_collection"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module.is_mutable_collection" title="Permalink to this definition">#</a></dt> <dd><p>Returns true if the collection <code class="docutils literal notranslate"><span class="pre">col</span></code> is mutable.</p> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Module.lazy_init"> <span class="sig-name descname"><span class="pre">lazy_init</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">rngs</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">method</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mutable</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">DenyList(deny='intermediates')</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a 
class="reference internal" href="../../_modules/flax/linen/module.html#Module.lazy_init"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module.lazy_init" title="Permalink to this definition">#</a></dt> <dd><p>Initializes a module without computing on an actual input.</p> <p>lazy_init will initialize the variables without doing unnecessary compute. The input data should be passed as a <code class="docutils literal notranslate"><span class="pre">jax.ShapeDtypeStruct</span></code> which specifies the shape and dtype of the input but no concrete data.</p> <p>Example:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">model</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">256</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">lazy_init</span><span class="p">(</span> <span class="gp">... 
</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jax</span><span class="o">.</span><span class="n">ShapeDtypeStruct</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">128</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">))</span> </pre></div> </div> <p>The args and kwargs passed to <code class="docutils literal notranslate"><span class="pre">lazy_init</span></code> can be a mix of concrete (jax arrays, scalars, bools) and abstract (ShapeDtypeStruct) values. Concrete values are only necessary for arguments that affect the initialization of variables. For example, the model might expect a keyword arg that enables/disables a subpart of the model. In this case, an explicit value (True/False) should be passed; otherwise <code class="docutils literal notranslate"><span class="pre">lazy_init</span></code> cannot infer which variables should be initialized.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>rngs</strong> – The rngs for the variable collections.</p></li> <li><p><strong>*args</strong> – Positional arguments passed to the init function.</p></li> <li><p><strong>method</strong> – An optional method. If provided, applies this method. If not provided, applies the <code class="docutils literal notranslate"><span class="pre">__call__</span></code> method.</p></li> <li><p><strong>mutable</strong> – Can be bool, str, or list. Specifies which collections should be treated as mutable: <code class="docutils literal notranslate"><span class="pre">bool</span></code>: all/no collections are mutable. 
<code class="docutils literal notranslate"><span class="pre">str</span></code>: The name of a single mutable collection. <code class="docutils literal notranslate"><span class="pre">list</span></code>: A list of names of mutable collections. By default all collections except “intermediates” are mutable.</p></li> <li><p><strong>**kwargs</strong> – Keyword arguments passed to the init function.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>The initialized variable dict.</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Module.make_rng"> <span class="sig-name descname"><span class="pre">make_rng</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">name</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'params'</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#Module.make_rng"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module.make_rng" title="Permalink to this definition">#</a></dt> <dd><p>Returns a new RNG key from a given RNG sequence for this Module.</p> <p>The new RNG key is split from the previous one. Thus, every call to <code class="docutils literal notranslate"><span class="pre">make_rng</span></code> returns a new RNG key, while still guaranteeing full reproducibility.</p> <div class="admonition note"> <p class="admonition-title">Note</p> <p>If an invalid name is passed (i.e. 
no RNG key was passed by the user in <code class="docutils literal notranslate"><span class="pre">.init</span></code> or <code class="docutils literal notranslate"><span class="pre">.apply</span></code> for this name), then <code class="docutils literal notranslate"><span class="pre">name</span></code> will default to <code class="docutils literal notranslate"><span class="pre">'params'</span></code>.</p> </div> <p>Example:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">ParamsModule</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="gp">... </span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">make_rng</span><span class="p">(</span><span class="s1">'params'</span><span class="p">)</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">OtherModule</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="gp">... 
</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">make_rng</span><span class="p">(</span><span class="s1">'other'</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">key</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">params_out</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">ParamsModule</span><span class="p">()</span><span class="o">.</span><span class="n">init_with_output</span><span class="p">({</span><span class="s1">'params'</span><span class="p">:</span> <span class="n">key</span><span class="p">})</span> <span class="gp">>>> </span><span class="c1"># self.make_rng('other') will default to using the 'params' RNG stream</span> <span class="gp">>>> </span><span class="n">other_out</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">OtherModule</span><span class="p">()</span><span class="o">.</span><span class="n">init_with_output</span><span class="p">({</span><span class="s1">'params'</span><span class="p">:</span> <span class="n">key</span><span class="p">})</span> <span class="gp">>>> </span><span class="k">assert</span> <span class="n">params_out</span> <span class="o">==</span> <span class="n">other_out</span> </pre></div> </div> <p>Learn more about RNGs by reading the Flax RNG guide: <a class="reference external" href="https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html">https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html</a></p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><p><strong>name</strong> – The RNG sequence name.</p> </dd> <dt
class="field-even">Returns</dt> <dd class="field-even"><p>The newly generated RNG key.</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Module.module_paths"> <span class="sig-name descname"><span class="pre">module_paths</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">rngs</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">show_repeated</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mutable</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">DenyList(deny='intermediates')</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#Module.module_paths"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module.module_paths" title="Permalink to this definition">#</a></dt> <dd><p>Returns a dictionary mapping module paths to module instances.</p> <p>This method has the same signature and internally calls <code class="docutils literal notranslate"><span class="pre">Module.init</span></code>, but instead of returning the variables, it returns a dictionary mapping module paths to unbound copies of module instances that were used at runtime.
<code class="docutils literal notranslate"><span class="pre">module_paths</span></code> uses <code class="docutils literal notranslate"><span class="pre">jax.eval_shape</span></code> to run the forward computation without consuming any FLOPs or allocating memory.</p> <p>Example:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">Foo</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="n">h</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">4</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... 
</span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">2</span><span class="p">)(</span><span class="n">h</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">16</span><span class="p">,</span> <span class="mi">9</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">modules</span> <span class="o">=</span> <span class="n">Foo</span><span class="p">()</span><span class="o">.</span><span class="n">module_paths</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="nb">print</span><span class="p">({</span> <span class="gp">... </span> <span class="n">p</span><span class="p">:</span> <span class="nb">type</span><span class="p">(</span><span class="n">m</span><span class="p">)</span><span class="o">.</span><span class="vm">__name__</span> <span class="k">for</span> <span class="n">p</span><span class="p">,</span> <span class="n">m</span> <span class="ow">in</span> <span class="n">modules</span><span class="o">.</span><span class="n">items</span><span class="p">()</span> <span class="gp">... 
</span><span class="p">})</span> <span class="go">{'': 'Foo', 'Dense_0': 'Dense', 'Dense_1': 'Dense'}</span> </pre></div> </div> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>rngs</strong> – The rngs for the variable collections as passed to <code class="docutils literal notranslate"><span class="pre">Module.init</span></code>.</p></li> <li><p><strong>*args</strong> – The arguments to the forward computation.</p></li> <li><p><strong>show_repeated</strong> – If <code class="docutils literal notranslate"><span class="pre">True</span></code>, repeated calls to the same module will be shown in the table, otherwise only the first call will be shown. Default is <code class="docutils literal notranslate"><span class="pre">False</span></code>.</p></li> <li><p><strong>mutable</strong> – Can be bool, str, or list. Specifies which collections should be treated as mutable: <code class="docutils literal notranslate"><span class="pre">bool</span></code>: all/no collections are mutable. <code class="docutils literal notranslate"><span class="pre">str</span></code>: The name of a single mutable collection. <code class="docutils literal notranslate"><span class="pre">list</span></code>: A list of names of mutable collections. 
By default, all collections except ‘intermediates’ are mutable.</p></li> <li><p><strong>**kwargs</strong> – keyword arguments to pass to the forward computation.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>A dictionary mapping module paths to module instances.</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Module.param"> <span class="sig-name descname"><span class="pre">param</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">init_fn</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">init_args</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">unbox</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">init_kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#Module.param"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module.param" title="Permalink to this definition">#</a></dt> <dd><p>Declares and returns a parameter in this Module.</p> <p>Parameters are read-only variables in the collection named “params”.
See <a class="reference internal" href="variable.html#module-flax.core.variables" title="flax.core.variables"><code class="xref py py-mod docutils literal notranslate"><span class="pre">flax.core.variables</span></code></a> for more details on variables.</p> <p>The first argument of <code class="docutils literal notranslate"><span class="pre">init_fn</span></code> is assumed to be a PRNG key, which is provided automatically and does not have to be passed using <code class="docutils literal notranslate"><span class="pre">init_args</span></code> or <code class="docutils literal notranslate"><span class="pre">init_kwargs</span></code>:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="k">class</span> <span class="nc">Foo</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">4</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... 
</span> <span class="n">mean</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">param</span><span class="p">(</span><span class="s1">'mean'</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">lecun_normal</span><span class="p">(),</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="gp">... </span> <span class="o">...</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="n">mean</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">Foo</span><span class="p">()</span><span class="o">.</span><span class="n">init</span><span class="p">({</span><span class="s1">'params'</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="s1">'stats'</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">1</span><span class="p">)},</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))</span> <span class="gp">>>> </span><span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">variables</span><span class="p">)</span> <span class="go">{'params': 
{'Dense_0': {'bias': (4,), 'kernel': (3, 4)}, 'mean': (2, 4)}}</span> </pre></div> </div> <p>In the example above, the function <code class="docutils literal notranslate"><span class="pre">lecun_normal</span></code> expects two arguments: <code class="docutils literal notranslate"><span class="pre">key</span></code> and <code class="docutils literal notranslate"><span class="pre">shape</span></code>, but only <code class="docutils literal notranslate"><span class="pre">shape</span></code> has to be provided explicitly; <code class="docutils literal notranslate"><span class="pre">key</span></code> is set automatically using the PRNG for <code class="docutils literal notranslate"><span class="pre">params</span></code> that is passed when initializing the module using <a class="reference internal" href="#flax.linen.init" title="flax.linen.init"><code class="xref py py-meth docutils literal notranslate"><span class="pre">init()</span></code></a>.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>name</strong> – The parameter name.</p></li> <li><p><strong>init_fn</strong> – The function that will be called to compute the initial value of this variable. This function will only be called the first time this parameter is used in this module.</p></li> <li><p><strong>*init_args</strong> – The positional arguments to pass to init_fn.</p></li> <li><p><strong>unbox</strong> – If True, <code class="docutils literal notranslate"><span class="pre">AxisMetadata</span></code> instances are replaced by their unboxed value, see <code class="docutils literal notranslate"><span class="pre">flax.nn.meta.unbox</span></code> (default: True).</p></li> <li><p><strong>**init_kwargs</strong> – The keyword arguments to pass to init_fn.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>The value of the initialized parameter.
Throws an error if the parameter exists already.</p> </dd> </dl> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="flax.linen.Module.path"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">path</span></span><a class="headerlink" href="#flax.linen.Module.path" title="Permalink to this definition">#</a></dt> <dd><p>Get the path of this Module. Top-level root modules have an empty path <code class="docutils literal notranslate"><span class="pre">()</span></code>. Note that this method can only be used on bound modules that have a valid scope.</p> <p>Example usage:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">SubModel</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'SubModel path: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">path</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span> <span class="gp">... 
</span> <span class="k">return</span> <span class="n">x</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s1">'Model path: </span><span class="si">{</span><span class="bp">self</span><span class="o">.</span><span class="n">path</span><span class="si">}</span><span class="s1">'</span><span class="p">)</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">SubModel</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">model</span> <span class="o">=</span> <span class="n">Model</span><span class="p">()</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)))</span> <span class="go">Model path: ()</span> <span class="go">SubModel path: ('SubModel_0',)</span> </pre></div> </div> </dd></dl> <dl class="py method"> <dt class="sig sig-object 
py" id="flax.linen.Module.perturb"> <span class="sig-name descname"><span class="pre">perturb</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">collection</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">'perturbations'</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#Module.perturb"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module.perturb" title="Permalink to this definition">#</a></dt> <dd><p>Add a zero-value variable (‘perturbation’) to the intermediate value.</p> <p>The gradient of <code class="docutils literal notranslate"><span class="pre">value</span></code> would be the same as the gradient of this perturbation variable. Therefore, if you define your loss function with both params and perturbations as standalone arguments, you can get the intermediate gradients of <code class="docutils literal notranslate"><span class="pre">value</span></code> by running <code class="docutils literal notranslate"><span class="pre">jax.grad</span></code> on the perturbation argument.</p> <div class="admonition note"> <p class="admonition-title">Note</p> <p>This is an experimental API and may be tweaked later for better performance and usability. At its current stage, it creates extra dummy variables that occupy extra memory space.
Use it only to debug gradients in training.</p> </div> <p>Example:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="k">class</span> <span class="nc">Foo</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">3</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">perturb</span><span class="p">(</span><span class="s1">'dense3'</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">2</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="k">def</span> <span class="nf">loss</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span><span class="p">):</span> <span class="gp">... 
</span> <span class="n">preds</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">inputs</span><span class="p">)</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">jnp</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">preds</span> <span class="o">-</span> <span class="n">targets</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">9</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">y</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">model</span> <span class="o">=</span> <span class="n">Foo</span><span class="p">()</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">intm_grads</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">grad</span><span class="p">(</span><span class="n">loss</span><span 
class="p">,</span> <span class="n">argnums</span><span class="o">=</span><span class="mi">0</span><span class="p">)(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> <span class="gp">>>> </span><span class="nb">print</span><span class="p">(</span><span class="n">intm_grads</span><span class="p">[</span><span class="s1">'perturbations'</span><span class="p">][</span><span class="s1">'dense3'</span><span class="p">])</span> <span class="go">[[-1.456924 -0.44332537 0.02422847]</span> <span class="go"> [-1.456924 -0.44332537 0.02422847]]</span> </pre></div> </div> <p>If perturbations are not passed to <code class="docutils literal notranslate"><span class="pre">apply</span></code>, <code class="docutils literal notranslate"><span class="pre">perturb</span></code> behaves like a no-op so you can easily disable the behavior when not needed:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="c1"># works as expected</span> <span class="go">Array([[-1.0980128 , -0.67961735],</span> <span class="go"> [-1.0980128 , -0.67961735]], dtype=float32)</span> <span class="gp">>>> </span><span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">({</span><span class="s1">'params'</span><span class="p">:</span> <span class="n">variables</span><span class="p">[</span><span class="s1">'params'</span><span class="p">]},</span> <span class="n">x</span><span class="p">)</span> <span class="c1"># behaves like a no-op</span> <span class="go">Array([[-1.0980128 , -0.67961735],</span> <span class="go"> [-1.0980128 , -0.67961735]], dtype=float32)</span> <span 
class="gp">>>> </span><span class="n">intm_grads</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">grad</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">argnums</span><span class="o">=</span><span class="mi">0</span><span class="p">)({</span><span class="s1">'params'</span><span class="p">:</span> <span class="n">variables</span><span class="p">[</span><span class="s1">'params'</span><span class="p">]},</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> <span class="gp">>>> </span><span class="s1">'perturbations'</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">intm_grads</span> <span class="go">True</span> </pre></div> </div> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Module.put_variable"> <span class="sig-name descname"><span class="pre">put_variable</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">col</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">value</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#Module.put_variable"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module.put_variable" title="Permalink to this definition">#</a></dt> <dd><p>Updates the value of the given variable if it is mutable, or an error otherwise.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>col</strong> – the variable collection.</p></li> <li><p><strong>name</strong> – the name of the variable.</p></li> <li><p><strong>value</strong> – the new value of the variable.</p></li> </ul> </dd> </dl> 
</dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Module.setup"> <span class="sig-name descname"><span class="pre">setup</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#Module.setup"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module.setup" title="Permalink to this definition">#</a></dt> <dd><p>Initializes a Module lazily (similar to a lazy <code class="docutils literal notranslate"><span class="pre">__init__</span></code>).</p> <p><code class="docutils literal notranslate"><span class="pre">setup</span></code> is called once lazily on a module instance when a module is bound, immediately before any other methods like <code class="docutils literal notranslate"><span class="pre">__call__</span></code> are invoked, or before a <code class="docutils literal notranslate"><span class="pre">setup</span></code>-defined attribute on <code class="docutils literal notranslate"><span class="pre">self</span></code> is accessed.</p> <p>This can happen in three cases:</p> <blockquote> <div><ol class="arabic"> <li><p>Immediately when invoking <a class="reference internal" href="#flax.linen.apply" title="flax.linen.apply"><code class="xref py py-meth docutils literal notranslate"><span class="pre">apply()</span></code></a>, <a class="reference internal" href="#flax.linen.init" title="flax.linen.init"><code class="xref py py-meth docutils literal notranslate"><span class="pre">init()</span></code></a> or <code class="xref py py-meth docutils literal notranslate"><span class="pre">init_with_output()</span></code>.</p></li> <li><p>Once the module is given a name by being assigned to an attribute of another module inside the other module’s <code class="docutils literal notranslate"><span class="pre">setup</span></code> method (see <a class="reference internal"
href="#flax.linen.Module.__setattr__" title="flax.linen.Module.__setattr__"><code class="xref py py-meth docutils literal notranslate"><span class="pre">__setattr__()</span></code></a>):</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="k">class</span> <span class="nc">MyModule</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="k">def</span> <span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="gp">... </span> <span class="n">submodule</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv</span><span class="p">(</span><span class="o">...</span><span class="p">)</span> <span class="go">... # Accessing `submodule` attributes does not yet work here.</span> <span class="go">... # The following line invokes `self.__setattr__`, which gives</span> <span class="go">... # `submodule` the name "conv1".</span> <span class="go">... self.conv1 = submodule</span> <span class="go">... # Accessing `submodule` attributes or methods is now safe and</span> <span class="go">... 
# either causes setup() to be called once.</span> </pre></div> </div> </li> <li><p>Once a module is constructed inside a method wrapped with <a class="reference internal" href="decorators.html#flax.linen.compact" title="flax.linen.compact"><code class="xref py py-meth docutils literal notranslate"><span class="pre">compact()</span></code></a>, immediately before another method is called or a <code class="docutils literal notranslate"><span class="pre">setup</span></code>-defined attribute is accessed.</p></li> </ol> </div></blockquote> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Module.sow"> <span class="sig-name descname"><span class="pre">sow</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">col</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">reduce_fn=&lt;function</span> <span class="pre">&lt;lambda&gt;&gt;</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">init_fn=&lt;function</span> <span class="pre">&lt;lambda&gt;&gt;</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#Module.sow"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module.sow" title="Permalink to this definition">#</a></dt> <dd><p>Stores a value in a collection.</p> <p>Collections can be used to collect intermediate values without the overhead of explicitly passing a container through each Module call.</p> <p>If the target collection is not mutable, <code class="docutils literal notranslate"><span class="pre">sow</span></code> behaves like a no-op and returns <code class="docutils literal notranslate"><span class="pre">False</span></code>.</p> <p>Example:</p> <div class="highlight-default 
notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">Foo</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="n">h</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">4</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... </span> <span class="bp">self</span><span class="o">.</span><span class="n">sow</span><span class="p">(</span><span class="s1">'intermediates'</span><span class="p">,</span> <span class="s1">'h'</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span> <span class="gp">... 
</span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">2</span><span class="p">)(</span><span class="n">h</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">16</span><span class="p">,</span> <span class="mi">9</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">model</span> <span class="o">=</span> <span class="n">Foo</span><span class="p">()</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">y</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mutable</span><span class="o">=</span><span class="p">[</span><span class="s1">'intermediates'</span><span class="p">])</span> <span class="gp">>>> </span><span class="n">jax</span><span class="o">.</span><span class="n">tree</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">state</span><span class="p">[</span><span class="s1">'intermediates'</span><span class="p">])</span> <span 
class="go">{'h': ((16, 4),)}</span> </pre></div> </div> <p>By default the values are stored in a tuple and each stored value is appended at the end. This way all intermediates can be tracked when the same module is called multiple times. Alternatively, a custom init/reduce function can be passed:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="k">class</span> <span class="nc">Foo2</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="n">init_fn</span> <span class="o">=</span> <span class="k">lambda</span><span class="p">:</span> <span class="mi">0</span> <span class="gp">... </span> <span class="n">reduce_fn</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">:</span> <span class="n">a</span> <span class="o">+</span> <span class="n">b</span> <span class="gp">... </span> <span class="bp">self</span><span class="o">.</span><span class="n">sow</span><span class="p">(</span><span class="s1">'intermediates'</span><span class="p">,</span> <span class="s1">'h'</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="gp">... </span> <span class="n">init_fn</span><span class="o">=</span><span class="n">init_fn</span><span class="p">,</span> <span class="n">reduce_fn</span><span class="o">=</span><span class="n">reduce_fn</span><span class="p">)</span> <span class="gp">... 
</span> <span class="bp">self</span><span class="o">.</span><span class="n">sow</span><span class="p">(</span><span class="s1">'intermediates'</span><span class="p">,</span> <span class="s1">'h'</span><span class="p">,</span> <span class="n">x</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="gp">... </span> <span class="n">init_fn</span><span class="o">=</span><span class="n">init_fn</span><span class="p">,</span> <span class="n">reduce_fn</span><span class="o">=</span><span class="n">reduce_fn</span><span class="p">)</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">x</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">model</span> <span class="o">=</span> <span class="n">Foo2</span><span class="p">()</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">y</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span> <span class="gp">... 
</span> <span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">mutable</span><span class="o">=</span><span class="p">[</span><span class="s1">'intermediates'</span><span class="p">])</span> <span class="gp">>>> </span><span class="nb">print</span><span class="p">(</span><span class="n">state</span><span class="p">[</span><span class="s1">'intermediates'</span><span class="p">])</span> <span class="go">{'h': Array([[3.]], dtype=float32)}</span> </pre></div> </div> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>col</strong> – The name of the variable collection.</p></li> <li><p><strong>name</strong> – The name of the variable.</p></li> <li><p><strong>value</strong> – The value of the variable.</p></li> <li><p><strong>reduce_fn</strong> – The function used to combine the existing value with the new value. The default is to append the value to a tuple.</p></li> <li><p><strong>init_fn</strong> – For the first value stored, <code class="docutils literal notranslate"><span class="pre">reduce_fn</span></code> will be passed the result of <code class="docutils literal notranslate"><span class="pre">init_fn</span></code> together with the value to be stored. 
The default is an empty tuple.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p><code class="docutils literal notranslate"><span class="pre">True</span></code> if the value has been stored successfully, <code class="docutils literal notranslate"><span class="pre">False</span></code> otherwise.</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Module.tabulate"> <span class="sig-name descname"><span class="pre">tabulate</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">rngs</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">args</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">depth</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">show_repeated</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mutable</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">DenyList(deny='intermediates')</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">console_kwargs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">table_kwargs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">mappingproxy({})</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">column_kwargs</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span 
class="pre">mappingproxy({})</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">compute_flops</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">compute_vjp_flops</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#Module.tabulate"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module.tabulate" title="Permalink to this definition">#</a></dt> <dd><p>Creates a summary of the Module represented as a table.</p> <p>This method has the same signature and internally calls <code class="docutils literal notranslate"><span class="pre">Module.init</span></code>, but instead of returning the variables, it returns the string summarizing the Module in a table. <code class="docutils literal notranslate"><span class="pre">tabulate</span></code> uses <code class="docutils literal notranslate"><span class="pre">jax.eval_shape</span></code> to run the forward computation without consuming any FLOPs or allocating memory.</p> <p>Additional arguments can be passed into the <code class="docutils literal notranslate"><span class="pre">console_kwargs</span></code> argument, for example, <code class="docutils literal notranslate"><span class="pre">{'width':</span> <span class="pre">120}</span></code>. 
For a full list of <code class="docutils literal notranslate"><span class="pre">console_kwargs</span></code> arguments, see: <a class="reference external" href="https://rich.readthedocs.io/en/stable/reference/console.html#rich.console.Console">https://rich.readthedocs.io/en/stable/reference/console.html#rich.console.Console</a></p> <p>Example:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">Foo</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="n">h</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">4</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... 
</span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">2</span><span class="p">)(</span><span class="n">h</span><span class="p">)</span> <span class="gp">>>> </span><span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">16</span><span class="p">,</span> <span class="mi">9</span><span class="p">))</span> <span class="gp">>>> </span><span class="c1"># print(Foo().tabulate(</span> <span class="gp">>>> </span><span class="c1"># jax.random.key(0), x, compute_flops=True, compute_vjp_flops=True))</span> </pre></div> </div> <p>This gives the following output:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span> Foo Summary ┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ ┃ path ┃ module ┃ inputs ┃ outputs ┃ flops ┃ vjp_flops ┃ params ┃ ┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ │ │ Foo │ float32[16,9] │ float32[16,2] │ 1504 │ 4460 │ │ ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ │ Dense_0 │ Dense │ float32[16,9] │ float32[16,4] │ 1216 │ 3620 │ bias: │ │ │ │ │ │ │ │ float32[4] │ │ │ │ │ │ │ │ kernel: │ │ │ │ │ │ │ │ float32[9,4] │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ 40 (160 B) │ ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ │ Dense_1 │ Dense │ float32[16,4] │ float32[16,2] │ 288 │ 840 │ bias: │ │ │ │ │ │ │ │ float32[2] │ │ │ │ │ │ │ │ kernel: │ │ │ │ │ │ │ │ float32[4,2] │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ │ 10 (40 B) │ ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ │ │ │ │ │ │ Total │ 50 (200 B) │ └─────────┴────────┴───────────────┴───────────────┴───────┴───────────┴─────────────────┘ Total Parameters: 50 (200 B) 
</pre></div> </div> <p><strong>Note</strong>: The order of rows in the table does not represent execution order; instead, it follows the order of keys in <code class="docutils literal notranslate"><span class="pre">variables</span></code>, which are sorted alphabetically.</p> <p><strong>Note</strong>: <code class="docutils literal notranslate"><span class="pre">vjp_flops</span></code> returns <code class="docutils literal notranslate"><span class="pre">0</span></code> if the module is not differentiable.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>rngs</strong> – The rngs for the variable collections as passed to <code class="docutils literal notranslate"><span class="pre">Module.init</span></code>.</p></li> <li><p><strong>*args</strong> – The arguments to the forward computation.</p></li> <li><p><strong>depth</strong> – Controls how many submodules deep the summary can go. By default, it is <code class="docutils literal notranslate"><span class="pre">None</span></code>, which means no limit. If a submodule is not shown because of the depth limit, its parameter count and bytes will be added to the row of its first shown ancestor such that the sum of all rows always adds up to the total number of parameters of the Module.</p></li> <li><p><strong>show_repeated</strong> – If <code class="docutils literal notranslate"><span class="pre">True</span></code>, repeated calls to the same module will be shown in the table, otherwise only the first call will be shown. Default is <code class="docutils literal notranslate"><span class="pre">False</span></code>.</p></li> <li><p><strong>mutable</strong> – Can be bool, str, or list. Specifies which collections should be treated as mutable: <code class="docutils literal notranslate"><span class="pre">bool</span></code>: all/no collections are mutable. 
<code class="docutils literal notranslate"><span class="pre">str</span></code>: The name of a single mutable collection. <code class="docutils literal notranslate"><span class="pre">list</span></code>: A list of names of mutable collections. By default, all collections except ‘intermediates’ are mutable.</p></li> <li><p><strong>console_kwargs</strong> – An optional dictionary with additional keyword arguments that are passed to <code class="docutils literal notranslate"><span class="pre">rich.console.Console</span></code> when rendering the table. Default arguments are <code class="docutils literal notranslate"><span class="pre">{'force_terminal':</span> <span class="pre">True,</span> <span class="pre">'force_jupyter':</span> <span class="pre">False}</span></code>.</p></li> <li><p><strong>table_kwargs</strong> – An optional dictionary with additional keyword arguments that are passed to the <code class="docutils literal notranslate"><span class="pre">rich.table.Table</span></code> constructor.</p></li> <li><p><strong>column_kwargs</strong> – An optional dictionary with additional keyword arguments that are passed to <code class="docutils literal notranslate"><span class="pre">rich.table.Table.add_column</span></code> when adding columns to the table.</p></li> <li><p><strong>compute_flops</strong> – whether to include a <code class="docutils literal notranslate"><span class="pre">flops</span></code> column in the table listing the estimated FLOPs cost of each module forward pass. Does not incur actual on-device computation / compilation / memory allocation, but still introduces overhead for large modules (e.g. extra 20 seconds for a Stable Diffusion’s UNet, whereas otherwise tabulation would finish in 5 seconds).</p></li> <li><p><strong>compute_vjp_flops</strong> – whether to include a <code class="docutils literal notranslate"><span class="pre">vjp_flops</span></code> column in the table listing the estimated FLOPs cost of each module backward pass. 
Introduces a compute overhead of about 2-3X of <code class="docutils literal notranslate"><span class="pre">compute_flops</span></code>.</p></li> <li><p><strong>**kwargs</strong> – keyword arguments to pass to the forward computation.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>A string summarizing the Module.</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Module.unbind"> <span class="sig-name descname"><span class="pre">unbind</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#Module.unbind"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module.unbind" title="Permalink to this definition">#</a></dt> <dd><p>Returns an unbound copy of a Module and its variables.</p> <p><code class="docutils literal notranslate"><span class="pre">unbind</span></code> helps create a stateless version of a bound Module.</p> <p>An example of a common use case: to extract a sub-Module defined inside <code class="docutils literal notranslate"><span class="pre">setup()</span></code> and its corresponding variables: 1) temporarily <code class="docutils literal notranslate"><span class="pre">bind</span></code> the parent Module; and then 2) <code class="docutils literal notranslate"><span class="pre">unbind</span></code> the desired sub-Module. (Recall that <code class="docutils literal notranslate"><span class="pre">setup()</span></code> is only called when the Module is bound.):</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="k">class</span> <span class="nc">Encoder</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... 
</span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="o">...</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">256</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">Decoder</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="o">...</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">784</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">AutoEncoder</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="k">def</span> <span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="gp">... 
</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">Encoder</span><span class="p">()</span> <span class="gp">... </span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">Decoder</span><span class="p">()</span> <span class="gp">...</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">x</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">module</span> <span class="o">=</span> <span class="n">AutoEncoder</span><span class="p">()</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">784</span><span class="p">)))</span> <span class="gp">>>> </span><span class="c1"># Extract the Encoder sub-Module and its variables</span> <span class="gp">>>> </span><span class="n">encoder</span><span class="p">,</span> <span class="n">encoder_vars</span> <span class="o">=</span> <span class="n">module</span><span class="o">.</span><span 
class="n">bind</span><span class="p">(</span><span class="n">variables</span><span class="p">)</span><span class="o">.</span><span class="n">encoder</span><span class="o">.</span><span class="n">unbind</span><span class="p">()</span> </pre></div> </div> <dl class="field-list simple"> <dt class="field-odd">Returns</dt> <dd class="field-odd"><p>A tuple with an unbound copy of this Module and its variables.</p> </dd> </dl> </dd></dl> <dl class="py method"> <dt class="sig sig-object py" id="flax.linen.Module.variable"> <span class="sig-name descname"><span class="pre">variable</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">col</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">name</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">init_fn</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">None</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">*</span></span><span class="n"><span class="pre">init_args</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">unbox</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">True</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">**</span></span><span class="n"><span class="pre">init_kwargs</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#Module.variable"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.Module.variable" title="Permalink to this definition">#</a></dt> <dd><p>Declares and returns a variable in this Module.</p> <p>See <a class="reference internal" href="variable.html#module-flax.core.variables" title="flax.core.variables"><code class="xref py py-mod docutils literal notranslate"><span 
class="pre">flax.core.variables</span></code></a> for more information. See also <a class="reference internal" href="#flax.linen.Module.param" title="flax.linen.Module.param"><code class="xref py py-meth docutils literal notranslate"><span class="pre">param()</span></code></a> for a shorthand way to define read-only variables in the “params” collection.</p> <p>Unlike <a class="reference internal" href="#flax.linen.Module.param" title="flax.linen.Module.param"><code class="xref py py-meth docutils literal notranslate"><span class="pre">param()</span></code></a>, all arguments for <code class="docutils literal notranslate"><span class="pre">init_fn</span></code>, including the PRNG key, must be passed explicitly:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="k">class</span> <span class="nc">Foo</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">4</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... </span> <span class="n">key</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">make_rng</span><span class="p">(</span><span class="s1">'stats'</span><span class="p">)</span> <span class="gp">... 
</span> <span class="n">mean</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">variable</span><span class="p">(</span><span class="s1">'stats'</span><span class="p">,</span> <span class="s1">'mean'</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">lecun_normal</span><span class="p">(),</span> <span class="n">key</span><span class="p">,</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="gp">... </span> <span class="o">...</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="n">mean</span><span class="o">.</span><span class="n">value</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">Foo</span><span class="p">()</span><span class="o">.</span><span class="n">init</span><span class="p">({</span><span class="s1">'params'</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="s1">'stats'</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">1</span><span class="p">)},</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))</span> <span class="gp">>>> </span><span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span 
class="n">jnp</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">variables</span><span class="p">)</span> <span class="go">{'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}}, 'stats': {'mean': (2, 4)}}</span> </pre></div> </div> <p>In the example above, the function <code class="docutils literal notranslate"><span class="pre">lecun_normal</span></code> expects two arguments: <code class="docutils literal notranslate"><span class="pre">key</span></code> and <code class="docutils literal notranslate"><span class="pre">shape</span></code>, and both must be passed explicitly. A PRNG key named <code class="docutils literal notranslate"><span class="pre">stats</span></code> must also be provided explicitly when calling <a class="reference internal" href="#flax.linen.Module.init" title="flax.linen.Module.init"><code class="xref py py-meth docutils literal notranslate"><span class="pre">init()</span></code></a> and <a class="reference internal" href="#flax.linen.Module.apply" title="flax.linen.Module.apply"><code class="xref py py-meth docutils literal notranslate"><span class="pre">apply()</span></code></a>.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>col</strong> – The variable collection name.</p></li> <li><p><strong>name</strong> – The variable name.</p></li> <li><p><strong>init_fn</strong> – The function that will be called to compute the initial value of this variable. This function will only be called the first time this variable is used in this module. 
If None, the variable must already be initialized; otherwise an error is raised.</p></li> <li><p><strong>*init_args</strong> – The positional arguments to pass to init_fn.</p></li> <li><p><strong>unbox</strong> – If True, <code class="docutils literal notranslate"><span class="pre">AxisMetadata</span></code> instances are replaced by their unboxed value, see <code class="docutils literal notranslate"><span class="pre">flax.core.meta.unbox</span></code> (default: True).</p></li> <li><p><strong>**init_kwargs</strong> – The keyword arguments to pass to init_fn.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>A <code class="xref py py-class docutils literal notranslate"><span class="pre">flax.core.variables.Variable</span></code> that can be read or set via its “.value” attribute. Throws an error if the variable exists already.</p> </dd> </dl> </dd></dl> <dl class="py property"> <dt class="sig sig-object py" id="flax.linen.Module.variables"> <em class="property"><span class="pre">property</span><span class="w"> </span></em><span class="sig-name descname"><span class="pre">variables</span></span><a class="headerlink" href="#flax.linen.Module.variables" title="Permalink to this definition">#</a></dt> <dd><p>Returns the variables in this module.</p> </dd></dl> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.apply"> <span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">apply</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">fn</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">module</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mutable</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em>, <em class="sig-param"><span class="n"><span 
class="pre">capture_intermediates</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#apply"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.apply" title="Permalink to this definition">#</a></dt> <dd><p>Creates an apply function to call <code class="docutils literal notranslate"><span class="pre">fn</span></code> with a bound module.</p> <p>Unlike <code class="docutils literal notranslate"><span class="pre">Module.apply</span></code> this function returns a new function with the signature <code class="docutils literal notranslate"><span class="pre">(variables,</span> <span class="pre">*args,</span> <span class="pre">rngs=None,</span> <span class="pre">**kwargs)</span> <span class="pre">-></span> <span class="pre">T</span></code> where <code class="docutils literal notranslate"><span class="pre">T</span></code> is the return type of <code class="docutils literal notranslate"><span class="pre">fn</span></code>. 
If <code class="docutils literal notranslate"><span class="pre">mutable</span></code> is not <code class="docutils literal notranslate"><span class="pre">False</span></code> the return type is a tuple where the second item is a <code class="docutils literal notranslate"><span class="pre">FrozenDict</span></code> with the mutated variables.</p> <p>The apply function that is returned can be directly composed with JAX transformations like <code class="docutils literal notranslate"><span class="pre">jax.jit</span></code>:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="k">class</span> <span class="nc">Foo</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="k">def</span> <span class="nf">encode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="o">...</span> <span class="gp">... </span> <span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="o">...</span> <span class="gp">>>> </span><span class="k">def</span> <span class="nf">f</span><span class="p">(</span><span class="n">foo</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="n">z</span> <span class="o">=</span> <span class="n">foo</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... 
</span> <span class="n">y</span> <span class="o">=</span> <span class="n">foo</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">z</span><span class="p">)</span> <span class="gp">... </span> <span class="c1"># ...</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">y</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="p">{}</span> <span class="gp">>>> </span><span class="n">foo</span> <span class="o">=</span> <span class="n">Foo</span><span class="p">()</span> <span class="gp">>>> </span><span class="n">f_jitted</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">jit</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">f</span><span class="p">,</span> <span class="n">foo</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">f_jitted</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))</span> </pre></div> </div> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>fn</strong> – The function that should be applied. The first argument passed will be a module instance of the <code class="docutils literal notranslate"><span class="pre">module</span></code> with variables and RNGs bound to it.</p></li> <li><p><strong>module</strong> – The <code class="docutils literal notranslate"><span class="pre">Module</span></code> that will be used to bind variables and RNGs to. 
The <code class="docutils literal notranslate"><span class="pre">Module</span></code> passed as the first argument to <code class="docutils literal notranslate"><span class="pre">fn</span></code> will be a clone of module.</p></li> <li><p><strong>mutable</strong> – Can be bool, str, or list. Specifies which collections should be treated as mutable: <code class="docutils literal notranslate"><span class="pre">bool</span></code>: all/no collections are mutable. <code class="docutils literal notranslate"><span class="pre">str</span></code>: The name of a single mutable collection. <code class="docutils literal notranslate"><span class="pre">list</span></code>: A list of names of mutable collections.</p></li> <li><p><strong>capture_intermediates</strong> – If <code class="docutils literal notranslate"><span class="pre">True</span></code>, captures intermediate return values of all Modules inside the “intermediates” collection. By default, only the return values of all <cite>__call__</cite> methods are stored. A function can be passed to change the filter behavior. 
The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>The apply function wrapping <code class="docutils literal notranslate"><span class="pre">fn</span></code>.</p> </dd> </dl> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.init"> <span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">init</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">fn</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">module</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mutable</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">DenyList(deny='intermediates')</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">capture_intermediates</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#init"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.init" title="Permalink to this definition">#</a></dt> <dd><p>Creates an init function to call <code class="docutils literal notranslate"><span class="pre">fn</span></code> with a bound module.</p> <p>Unlike <code class="docutils literal notranslate"><span class="pre">Module.init</span></code> this function returns a new function with the signature <code class="docutils literal notranslate"><span class="pre">(rngs,</span> <span class="pre">*args,</span> <span class="pre">**kwargs)</span> <span class="pre">-></span> <span 
class="pre">variables</span></code>. The rngs can be a dict of PRNGKeys or a single <code class="docutils literal notranslate"><span class="pre">PRNGKey</span></code>, which is equivalent to passing a dict with one PRNGKey under the name “params”.</p> <p>The init function that is returned can be directly composed with JAX transformations like <code class="docutils literal notranslate"><span class="pre">jax.jit</span></code>:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="k">class</span> <span class="nc">Foo</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="k">def</span> <span class="nf">encode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="o">...</span> <span class="gp">... </span> <span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="o">...</span> <span class="gp">>>> </span><span class="k">def</span> <span class="nf">f</span><span class="p">(</span><span class="n">foo</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="n">z</span> <span class="o">=</span> <span class="n">foo</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... </span> <span class="n">y</span> <span class="o">=</span> <span class="n">foo</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">z</span><span class="p">)</span> <span class="gp">... 
</span> <span class="c1"># ...</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">y</span> <span class="gp">>>> </span><span class="n">foo</span> <span class="o">=</span> <span class="n">Foo</span><span class="p">()</span> <span class="gp">>>> </span><span class="n">f_jitted</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">jit</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">f</span><span class="p">,</span> <span class="n">foo</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">variables</span> <span class="o">=</span> <span class="n">f_jitted</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))</span> </pre></div> </div> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>fn</strong> – The function that should be applied. The first argument passed will be a module instance of the <code class="docutils literal notranslate"><span class="pre">module</span></code> with variables and RNGs bound to it.</p></li> <li><p><strong>module</strong> – The <code class="docutils literal notranslate"><span class="pre">Module</span></code> that will be used to bind variables and RNGs to. 
The <code class="docutils literal notranslate"><span class="pre">Module</span></code> passed as the first argument to <code class="docutils literal notranslate"><span class="pre">fn</span></code> will be a clone of module.</p></li> <li><p><strong>mutable</strong> – Can be bool, str, or list. Specifies which collections should be treated as mutable: <code class="docutils literal notranslate"><span class="pre">bool</span></code>: all/no collections are mutable. <code class="docutils literal notranslate"><span class="pre">str</span></code>: The name of a single mutable collection. <code class="docutils literal notranslate"><span class="pre">list</span></code>: A list of names of mutable collections. By default, all collections except “intermediates” are mutable.</p></li> <li><p><strong>capture_intermediates</strong> – If <code class="docutils literal notranslate"><span class="pre">True</span></code>, captures intermediate return values of all Modules inside the “intermediates” collection. By default, only the return values of all <cite>__call__</cite> methods are stored. A function can be passed to change the filter behavior. 
The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>The init function wrapping <code class="docutils literal notranslate"><span class="pre">fn</span></code>.</p> </dd> </dl> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.init_with_output"> <span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">init_with_output</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">fn</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">module</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mutable</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">DenyList(deny='intermediates')</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">capture_intermediates</span></span><span class="o"><span class="pre">=</span></span><span class="default_value"><span class="pre">False</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#init_with_output"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.init_with_output" title="Permalink to this definition">#</a></dt> <dd><p>Creates an init function to call <code class="docutils literal notranslate"><span class="pre">fn</span></code> with a bound module that also returns the function outputs.</p> <p>Unlike <code class="docutils literal notranslate"><span class="pre">Module.init_with_output</span></code> this function returns a new function with the signature <code class="docutils literal notranslate"><span class="pre">(rngs,</span> <span 
class="pre">*args,</span> <span class="pre">**kwargs)</span> <span class="pre">-></span> <span class="pre">(T,</span> <span class="pre">variables)</span></code> where <code class="docutils literal notranslate"><span class="pre">T</span></code> is the return type of <code class="docutils literal notranslate"><span class="pre">fn</span></code>. The rngs can be a dict of PRNGKeys or a single <code class="docutils literal notranslate"><span class="pre">PRNGKey</span></code>, which is equivalent to passing a dict with one PRNGKey under the name “params”.</p> <p>The init function that is returned can be directly composed with JAX transformations like <code class="docutils literal notranslate"><span class="pre">jax.jit</span></code>:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="k">class</span> <span class="nc">Foo</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="k">def</span> <span class="nf">encode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="o">...</span> <span class="gp">... </span> <span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="o">...</span> <span class="gp">>>> </span><span class="k">def</span> <span class="nf">f</span><span class="p">(</span><span class="n">foo</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... 
</span> <span class="n">z</span> <span class="o">=</span> <span class="n">foo</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="gp">... </span> <span class="n">y</span> <span class="o">=</span> <span class="n">foo</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">z</span><span class="p">)</span> <span class="gp">... </span> <span class="c1"># ...</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">y</span> <span class="gp">>>> </span><span class="n">foo</span> <span class="o">=</span> <span class="n">Foo</span><span class="p">()</span> <span class="gp">>>> </span><span class="n">f_jitted</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">jit</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">init_with_output</span><span class="p">(</span><span class="n">f</span><span class="p">,</span> <span class="n">foo</span><span class="p">))</span> <span class="gp">>>> </span><span class="n">y</span><span class="p">,</span> <span class="n">variables</span> <span class="o">=</span> <span class="n">f_jitted</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))</span> </pre></div> </div> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>fn</strong> – The function that should be applied. 
The first argument passed will be a module instance of the <code class="docutils literal notranslate"><span class="pre">module</span></code> with variables and RNGs bound to it.</p></li> <li><p><strong>module</strong> – The <code class="docutils literal notranslate"><span class="pre">Module</span></code> that will be used to bind variables and RNGs to. The <code class="docutils literal notranslate"><span class="pre">Module</span></code> passed as the first argument to <code class="docutils literal notranslate"><span class="pre">fn</span></code> will be a clone of module.</p></li> <li><p><strong>mutable</strong> – Can be bool, str, or list. Specifies which collections should be treated as mutable: <code class="docutils literal notranslate"><span class="pre">bool</span></code>: all/no collections are mutable. <code class="docutils literal notranslate"><span class="pre">str</span></code>: The name of a single mutable collection. <code class="docutils literal notranslate"><span class="pre">list</span></code>: A list of names of mutable collections. By default, all collections except “intermediates” are mutable.</p></li> <li><p><strong>capture_intermediates</strong> – If <code class="docutils literal notranslate"><span class="pre">True</span></code>, captures intermediate return values of all Modules inside the “intermediates” collection. By default, only the return values of all <cite>__call__</cite> methods are stored. A function can be passed to change the filter behavior. 
The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>The init function wrapping <code class="docutils literal notranslate"><span class="pre">fn</span></code>.</p> </dd> </dl> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.intercept_methods"> <span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">intercept_methods</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">interceptor</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#intercept_methods"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.intercept_methods" title="Permalink to this definition">#</a></dt> <dd><p>Registers a new method interceptor.</p> <p>Method interceptors allow you to intercept method calls to modules at a distance. They work similarly to decorators: you can modify args/kwargs before calling the underlying method and/or modify the result returned by the underlying method. You can also skip calling the underlying method entirely and do something else instead. 
For example:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">...</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">Foo</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="gp">... </span> <span class="k">return</span> <span class="n">x</span> <span class="gp">...</span> <span class="gp">>>> </span><span class="k">def</span> <span class="nf">my_interceptor1</span><span class="p">(</span><span class="n">next_fun</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="n">kwargs</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span> <span class="gp">... </span> <span class="nb">print</span><span class="p">(</span><span class="s1">'calling my_interceptor1'</span><span class="p">)</span> <span class="gp">... 
</span> <span class="k">return</span> <span class="n">next_fun</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="gp">...</span> <span class="gp">>>> </span><span class="n">foo</span> <span class="o">=</span> <span class="n">Foo</span><span class="p">()</span> <span class="gp">>>> </span><span class="k">with</span> <span class="n">nn</span><span class="o">.</span><span class="n">intercept_methods</span><span class="p">(</span><span class="n">my_interceptor1</span><span class="p">):</span> <span class="gp">... </span> <span class="n">_</span> <span class="o">=</span> <span class="n">foo</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">([</span><span class="mi">1</span><span class="p">]))</span> <span class="go">calling my_interceptor1</span> </pre></div> </div> <p>You could also register multiple interceptors on the same method. Interceptors will run in order. For example:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="k">def</span> <span class="nf">my_interceptor2</span><span class="p">(</span><span class="n">next_fun</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="n">kwargs</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span> <span class="gp">... </span> <span class="nb">print</span><span class="p">(</span><span class="s1">'calling my_interceptor2'</span><span class="p">)</span> <span class="gp">... 
</span> <span class="k">return</span> <span class="n">next_fun</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="gp">...</span> <span class="gp">>>> </span><span class="k">with</span> <span class="n">nn</span><span class="o">.</span><span class="n">intercept_methods</span><span class="p">(</span><span class="n">my_interceptor1</span><span class="p">),</span> \ <span class="gp">... </span> <span class="n">nn</span><span class="o">.</span><span class="n">intercept_methods</span><span class="p">(</span><span class="n">my_interceptor2</span><span class="p">):</span> <span class="gp">... </span> <span class="n">_</span> <span class="o">=</span> <span class="n">foo</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">([</span><span class="mi">1</span><span class="p">]))</span> <span class="go">calling my_interceptor1</span> <span class="go">calling my_interceptor2</span> </pre></div> </div> <p>You could skip other interceptors by directly calling the <code class="docutils literal notranslate"><span class="pre">context.orig_method</span></code>. For example:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="k">def</span> <span class="nf">my_interceptor3</span><span class="p">(</span><span class="n">next_fun</span><span class="p">,</span> <span class="n">args</span><span class="p">,</span> <span class="n">kwargs</span><span class="p">,</span> <span class="n">context</span><span class="p">):</span> <span class="gp">... </span> <span class="nb">print</span><span class="p">(</span><span class="s1">'calling my_interceptor3'</span><span class="p">)</span> <span class="gp">... 
</span> <span class="k">return</span> <span class="n">context</span><span class="o">.</span><span class="n">orig_method</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="gp">>>> </span><span class="k">with</span> <span class="n">nn</span><span class="o">.</span><span class="n">intercept_methods</span><span class="p">(</span><span class="n">my_interceptor3</span><span class="p">),</span> \ <span class="gp">... </span> <span class="n">nn</span><span class="o">.</span><span class="n">intercept_methods</span><span class="p">(</span><span class="n">my_interceptor1</span><span class="p">),</span> \ <span class="gp">... </span> <span class="n">nn</span><span class="o">.</span><span class="n">intercept_methods</span><span class="p">(</span><span class="n">my_interceptor2</span><span class="p">):</span> <span class="gp">... </span> <span class="n">_</span> <span class="o">=</span> <span class="n">foo</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">([</span><span class="mi">1</span><span class="p">]))</span> <span class="go">calling my_interceptor3</span> </pre></div> </div> <p>The following methods cannot be intercepted:</p> <ol class="arabic simple"> <li><p>Methods decorated with <code class="docutils literal notranslate"><span class="pre">nn.nowrap</span></code>.</p></li> <li><p>Dunder methods including <code class="docutils literal notranslate"><span class="pre">__eq__</span></code>, <code class="docutils literal notranslate"><span class="pre">__repr__</span></code>, <code class="docutils literal notranslate"><span class="pre">__init__</span></code>, <code class="docutils literal notranslate"><span class="pre">__hash__</span></code>, and <code class="docutils literal notranslate"><span class="pre">__post_init__</span></code>.</p></li> <li><p>Module 
dataclass fields.</p></li> <li><p>Module descriptors.</p></li> </ol> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><p><strong>interceptor</strong> – A method interceptor.</p> </dd> </dl> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.share_scope"> <span class="sig-prename descclassname"><span class="pre">flax.linen.</span></span><span class="sig-name descname"><span class="pre">share_scope</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">module</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">other</span></span></em>, <em class="sig-param"><span class="o"><span class="pre">/</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/module.html#share_scope"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.share_scope" title="Permalink to this definition">#</a></dt> <dd><p>Modifies one of the Modules such that they share the same scope. This is useful when you want to wrap a Module and extend its functionality without changing the parameter structure.</p> <p><code class="docutils literal notranslate"><span class="pre">share_scope</span></code> takes two Modules, <code class="docutils literal notranslate"><span class="pre">module</span></code> and <code class="docutils literal notranslate"><span class="pre">other</span></code>. 
<code class="docutils literal notranslate"><span class="pre">module</span></code> will use <code class="docutils literal notranslate"><span class="pre">other</span></code>’s scope if <code class="docutils literal notranslate"><span class="pre">other</span></code> has a scope and it is not a descendant of <code class="docutils literal notranslate"><span class="pre">module</span></code>’s scope:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span> <span class="gp">>>> </span><span class="kn">import</span> <span class="nn">jax</span> <span class="gp">>>> </span><span class="kn">from</span> <span class="nn">jax</span> <span class="kn">import</span> <span class="n">numpy</span> <span class="k">as</span> <span class="n">jnp</span><span class="p">,</span> <span class="n">random</span> <span class="gp">...</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">DenseLoRA</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="n">base</span><span class="p">:</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span> <span class="gp">... </span> <span class="n">rank</span><span class="p">:</span> <span class="nb">int</span> <span class="gp">...</span> <span class="gp">... </span> <span class="k">def</span> <span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="gp">... </span> <span class="n">nn</span><span class="o">.</span><span class="n">share_scope</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">base</span><span class="p">)</span> <span class="gp">...</span> <span class="gp">... 
</span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span><span class="p">):</span> <span class="gp">... </span> <span class="n">din</span><span class="p">,</span> <span class="n">dout</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">base</span><span class="o">.</span><span class="n">features</span> <span class="gp">... </span> <span class="n">A</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">param</span><span class="p">(</span><span class="s1">'A'</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">zeros_init</span><span class="p">(),</span> <span class="p">(</span><span class="n">din</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">rank</span><span class="p">))</span> <span class="gp">... </span> <span class="n">B</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">param</span><span class="p">(</span><span class="s1">'B'</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">zeros_init</span><span class="p">(),</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">rank</span><span class="p">,</span> <span class="n">dout</span><span class="p">))</span> <span class="gp">... 
</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">base</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">+</span> <span class="n">x</span> <span class="o">@</span> <span class="n">A</span> <span class="o">@</span> <span class="n">B</span> <span class="gp">...</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span><span class="p">):</span> <span class="gp">... </span> <span class="n">dense</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">)</span> <span class="c1"># base scope</span> <span class="gp">... 
</span> <span class="k">return</span> <span class="n">DenseLoRA</span><span class="p">(</span><span class="n">dense</span><span class="p">,</span> <span class="n">rank</span><span class="o">=</span><span class="mi">2</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># reuse the base scope</span> <span class="gp">...</span> <span class="gp">>>> </span><span class="n">model</span> <span class="o">=</span> <span class="n">Model</span><span class="p">()</span> <span class="gp">...</span> <span class="gp">>>> </span><span class="n">params</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">5</span><span class="p">)))[</span><span class="s1">'params'</span><span class="p">]</span> <span class="gp">>>> </span><span class="nb">list</span><span class="p">(</span><span class="n">params</span><span class="p">[</span><span class="s1">'Dense_0'</span><span class="p">]</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="go">['A', 'B', 'kernel', 'bias']</span> </pre></div> </div> <p>When <code class="docutils literal notranslate"><span class="pre">other</span></code>’s scope is a descendant of <code class="docutils literal notranslate"><span class="pre">module</span></code>’s scope then <code class="docutils literal notranslate"><span class="pre">other</span></code> will use <code class="docutils literal notranslate"><span class="pre">module</span></code>’s scope instead:</p> <div class="highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">>>> </span><span 
class="k">class</span> <span class="nc">DenseLoRA</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="n">features</span><span class="p">:</span> <span class="nb">int</span> <span class="gp">... </span> <span class="n">rank</span><span class="p">:</span> <span class="nb">int</span> <span class="gp">...</span> <span class="gp">... </span> <span class="k">def</span> <span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="gp">... </span> <span class="bp">self</span><span class="o">.</span><span class="n">child</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">features</span><span class="p">)</span> <span class="gp">... </span> <span class="n">nn</span><span class="o">.</span><span class="n">share_scope</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">child</span><span class="p">)</span> <span class="gp">...</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span><span class="p">):</span> <span class="gp">... 
</span> <span class="n">din</span><span class="p">,</span> <span class="n">dout</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">features</span> <span class="gp">... </span> <span class="n">A</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">param</span><span class="p">(</span><span class="s1">'A'</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">zeros_init</span><span class="p">(),</span> <span class="p">(</span><span class="n">din</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">rank</span><span class="p">))</span> <span class="gp">... </span> <span class="n">B</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">param</span><span class="p">(</span><span class="s1">'B'</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">zeros_init</span><span class="p">(),</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">rank</span><span class="p">,</span> <span class="n">dout</span><span class="p">))</span> <span class="gp">... 
</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">child</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">+</span> <span class="n">x</span> <span class="o">@</span> <span class="n">A</span> <span class="o">@</span> <span class="n">B</span> <span class="gp">...</span> <span class="gp">>>> </span><span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="gp">... </span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="gp">... </span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">jax</span><span class="o">.</span><span class="n">Array</span><span class="p">):</span> <span class="gp">... 
</span> <span class="k">return</span> <span class="n">DenseLoRA</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">rank</span><span class="o">=</span><span class="mi">2</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="gp">...</span> <span class="gp">>>> </span><span class="n">model</span> <span class="o">=</span> <span class="n">Model</span><span class="p">()</span> <span class="gp">...</span> <span class="gp">>>> </span><span class="n">params</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">5</span><span class="p">)))[</span><span class="s1">'params'</span><span class="p">]</span> <span class="gp">>>> </span><span class="nb">list</span><span class="p">(</span><span class="n">params</span><span class="p">[</span><span class="s1">'DenseLoRA_0'</span><span class="p">]</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="go">['A', 'B', 'kernel', 'bias']</span> </pre></div> </div> </dd></dl> </div> </article> <footer class="prev-next-footer d-print-none"> <div class="prev-next-area"> <a class="left-prev" href="index.html" title="previous page"> <i class="fa-solid fa-angle-left"></i> <div class="prev-next-info"> <p class="prev-next-subtitle">previous</p> <p class="prev-next-title">flax.linen</p> </div> </a> <a class="right-next" href="init_apply.html" title="next page"> <div class="prev-next-info"> <p class="prev-next-subtitle">next</p> <p class="prev-next-title">Init/Apply</p> </div> <i class="fa-solid fa-angle-right"></i> </a> </div> 
</footer> </div> <div class="bd-sidebar-secondary bd-toc"><div class="sidebar-secondary-items sidebar-secondary__inner"> <div class="sidebar-secondary-item"> <div class="page-toc tocsection onthispage"> <i class="fa-solid fa-list"></i> Contents </div> <nav class="bd-toc-nav page-toc"> <ul class="visible nav section-nav flex-column"> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module"><code class="docutils literal notranslate"><span class="pre">Module</span></code></a><ul class="nav section-nav flex-column"> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.__setattr__"><code class="docutils literal notranslate"><span class="pre">Module.__setattr__()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.apply"><code class="docutils literal notranslate"><span class="pre">Module.apply()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.bind"><code class="docutils literal notranslate"><span class="pre">Module.bind()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.copy"><code class="docutils literal notranslate"><span class="pre">Module.copy()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.get_variable"><code class="docutils literal notranslate"><span class="pre">Module.get_variable()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.has_rng"><code class="docutils literal notranslate"><span class="pre">Module.has_rng()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.has_variable"><code class="docutils literal 
notranslate"><span class="pre">Module.has_variable()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.init"><code class="docutils literal notranslate"><span class="pre">Module.init()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.init_with_output"><code class="docutils literal notranslate"><span class="pre">Module.init_with_output()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.is_initializing"><code class="docutils literal notranslate"><span class="pre">Module.is_initializing()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.is_mutable_collection"><code class="docutils literal notranslate"><span class="pre">Module.is_mutable_collection()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.lazy_init"><code class="docutils literal notranslate"><span class="pre">Module.lazy_init()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.make_rng"><code class="docutils literal notranslate"><span class="pre">Module.make_rng()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.module_paths"><code class="docutils literal notranslate"><span class="pre">Module.module_paths()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.param"><code class="docutils literal notranslate"><span class="pre">Module.param()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.path"><code class="docutils literal notranslate"><span 
class="pre">Module.path</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.perturb"><code class="docutils literal notranslate"><span class="pre">Module.perturb()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.put_variable"><code class="docutils literal notranslate"><span class="pre">Module.put_variable()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.setup"><code class="docutils literal notranslate"><span class="pre">Module.setup()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.sow"><code class="docutils literal notranslate"><span class="pre">Module.sow()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.tabulate"><code class="docutils literal notranslate"><span class="pre">Module.tabulate()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.unbind"><code class="docutils literal notranslate"><span class="pre">Module.unbind()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.variable"><code class="docutils literal notranslate"><span class="pre">Module.variable()</span></code></a></li> <li class="toc-h3 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.Module.variables"><code class="docutils literal notranslate"><span class="pre">Module.variables</span></code></a></li> </ul> </li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.apply"><code class="docutils literal notranslate"><span class="pre">apply()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a 
class="reference internal nav-link" href="#flax.linen.init"><code class="docutils literal notranslate"><span class="pre">init()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.init_with_output"><code class="docutils literal notranslate"><span class="pre">init_with_output()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.intercept_methods"><code class="docutils literal notranslate"><span class="pre">intercept_methods()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.share_scope"><code class="docutils literal notranslate"><span class="pre">share_scope()</span></code></a></li> </ul> </nav></div> </div></div> </div> <footer class="bd-footer-content"> <div class="bd-footer-content__inner container"> <div class="footer-item"> <p class="component-author"> By The Flax authors </p> </div> <div class="footer-item"> <p class="copyright"> © Copyright 2023, The Flax authors. <br/> </p> </div> <div class="footer-item"> </div> <div class="footer-item"> </div> </div> </footer> </main> </div> </div> <!-- Scripts loaded after <body> so the DOM is not blocked --> <script src="../../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b"></script> <script src="../../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b"></script> <footer class="bd-footer"> </footer> </body> </html>