<!DOCTYPE html> <html lang="en" data-content_root="" > <head> <meta charset="utf-8" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" /> <title>Migrating from Haiku to Flax</title> <script data-cfasync="false"> document.documentElement.dataset.mode = localStorage.getItem("mode") || ""; document.documentElement.dataset.theme = localStorage.getItem("theme") || ""; </script> <!-- Loaded before other Sphinx assets --> <link href="../../_static/styles/theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link href="../../_static/styles/bootstrap.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link href="../../_static/styles/pydata-sphinx-theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link href="../../_static/vendor/fontawesome/6.5.2/css/all.min.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="../../_static/vendor/fontawesome/6.5.2/webfonts/fa-solid-900.woff2" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="../../_static/vendor/fontawesome/6.5.2/webfonts/fa-brands-400.woff2" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="../../_static/vendor/fontawesome/6.5.2/webfonts/fa-regular-400.woff2" /> <link rel="stylesheet" type="text/css" href="../../_static/pygments.css" /> <link rel="stylesheet" type="text/css" href="../../_static/styles/sphinx-book-theme.css" /> <link rel="stylesheet" type="text/css" href="../../_static/mystnb.4510f1fc1dee50b3e5859aac5469c37c29e427902b24a333a5f9fcb2f0b3ac41.css" /> <link rel="stylesheet" type="text/css" href="../../_static/sphinx-design.5ea377869091fd0449014c60fc090103.min.css" /> <link rel="stylesheet" type="text/css" href="../../_static/css/flax_theme.css" /> <!-- Pre-loaded scripts that we'll load fully later --> <link rel="preload" as="script" href="../../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b" /> <link rel="preload" as="script" 
href="../../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" /> <script src="../../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script> <script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script> <script src="../../_static/jquery.js"></script> <script src="../../_static/underscore.js"></script> <script src="../../_static/_sphinx_javascript_frameworks_compat.js"></script> <script src="../../_static/doctools.js"></script> <script src="../../_static/sphinx_highlight.js"></script> <script src="../../_static/scripts/sphinx-book-theme.js"></script> <script src="../../_static/design-tabs.js"></script> <script>DOCUMENTATION_OPTIONS.pagename = 'guides/converting_and_upgrading/haiku_migration_guide';</script> <link rel="shortcut icon" href="../../_static/flax.png"/> <link rel="index" title="Index" href="../../genindex.html" /> <link rel="search" title="Search" href="../../search.html" /> <link rel="next" title="Convert PyTorch models to Flax" href="convert_pytorch_to_flax.html" /> <link rel="prev" title="Converting and upgrading" href="index.html" /> <meta name="viewport" content="width=device-width, initial-scale=1"/> <meta name="docsearch:language" content="en"/> <script async type="text/javascript" src="/_/static/javascript/readthedocs-addons.js"></script><meta name="readthedocs-project-slug" content="flax-linen" /><meta name="readthedocs-version-slug" content="latest" /><meta name="readthedocs-resolver-filename" content="/guides/converting_and_upgrading/haiku_migration_guide.html" /><meta name="readthedocs-http-status" content="200" /></head> <body data-bs-spy="scroll" data-bs-target=".bd-toc-nav" data-offset="180" data-bs-root-margin="0px 0px -60%" data-default-mode=""> <div id="pst-skip-link" class="skip-link d-print-none"><a href="#main-content">Skip to main content</a></div> <div id="pst-scroll-pixel-helper"></div> <button type="button" class="btn rounded-pill" 
id="pst-back-to-top"> <i class="fa-solid fa-arrow-up"></i>Back to top</button> <input type="checkbox" class="sidebar-toggle" id="pst-primary-sidebar-checkbox"/> <label class="overlay overlay-primary" for="pst-primary-sidebar-checkbox"></label> <input type="checkbox" class="sidebar-toggle" id="pst-secondary-sidebar-checkbox"/> <label class="overlay overlay-secondary" for="pst-secondary-sidebar-checkbox"></label> <div class="search-button__wrapper"> <div class="search-button__overlay"></div> <div class="search-button__search-container"> <form class="bd-search d-flex align-items-center" action="../../search.html" method="get"> <i class="fa-solid fa-magnifying-glass"></i> <input type="search" class="form-control" name="q" id="search-input" placeholder="Search..." aria-label="Search..." autocomplete="off" autocorrect="off" autocapitalize="off" spellcheck="false"/> <span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span> </form></div> </div> <div class="pst-async-banner-revealer d-none"> <aside id="bd-header-version-warning" class="d-none d-print-none" aria-label="Version warning"></aside> </div> <aside class="bd-header-announcement" aria-label="Announcement"> <div class="bd-header-announcement__content"> <a href="https://flax.readthedocs.io/en/latest/index.html" style="text-decoration: none; color: white;" > This site covers the old Flax Linen API. 
<span style="color: lightgray;">[Explore the new <b>Flax NNX</b> API ✨]</span> </a> </div> </aside> <header class="bd-header navbar navbar-expand-lg bd-navbar d-print-none"> </header> <div class="bd-container"> <div class="bd-container__inner bd-page-width"> <div class="bd-sidebar-primary bd-sidebar"> <div class="sidebar-header-items sidebar-primary__section"> </div> <div class="sidebar-primary-items__start sidebar-primary__section"> <div class="sidebar-primary-item"> <a class="navbar-brand logo" href="../../index.html"> <img src="../../_static/flax.png" class="logo__image only-light" alt=" - Home"/> <script>document.write(`<img src="../../_static/flax.png" class="logo__image only-dark" alt=" - Home"/>`);</script> </a></div> <div class="sidebar-primary-item"> <script> document.write(` <button class="btn search-button-field search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="fa-solid fa-magnifying-glass"></i> <span class="search-button__default-text">Search</span> <span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span> </button> `); </script></div> <div class="sidebar-primary-item"><nav class="bd-links bd-docs-nav" aria-label="Main"> <div class="bd-toc-item navbar-nav active"> <ul class="current nav bd-sidenav"> <li class="toctree-l1"><a class="reference internal" href="../../quick_start.html">Quick start</a></li> <li class="toctree-l1"><a class="reference internal" href="../flax_fundamentals/flax_basics.html">Flax Basics</a></li> <li class="toctree-l1 current active has-children"><a class="reference internal" href="../index.html">Guides</a><details open="open"><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul class="current"> <li class="toctree-l2 has-children"><a class="reference internal" href="../flax_fundamentals/index.html">Flax 
fundamentals</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference external" href="https://jax.readthedocs.io/en/latest/jax-101/index.html">JAX 101</a></li> <li class="toctree-l3"><a class="reference internal" href="../flax_fundamentals/flax_basics.html">Flax Basics</a></li> <li class="toctree-l3"><a class="reference internal" href="../flax_fundamentals/state_params.html">Managing Parameters and State</a></li> <li class="toctree-l3"><a class="reference internal" href="../flax_fundamentals/setup_or_nncompact.html"><code class="docutils literal notranslate"><span class="pre">setup</span></code> vs <code class="docutils literal notranslate"><span class="pre">compact</span></code></a></li> <li class="toctree-l3"><a class="reference internal" href="../flax_fundamentals/arguments.html">Dealing with Flax Module arguments</a></li> <li class="toctree-l3"><a class="reference internal" href="../flax_fundamentals/rng_guide.html">Randomness and PRNGs in Flax</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../data_preprocessing/index.html">Data preprocessing</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../data_preprocessing/full_eval.html">Processing the entire Dataset</a></li> <li class="toctree-l3"><a class="reference internal" href="../data_preprocessing/loading_datasets.html">Loading datasets</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../training_techniques/index.html">Training techniques</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" 
href="../training_techniques/batch_norm.html">Batch normalization</a></li> <li class="toctree-l3"><a class="reference internal" href="../training_techniques/dropout.html">Dropout</a></li> <li class="toctree-l3"><a class="reference internal" href="../training_techniques/lr_schedule.html">Learning rate scheduling</a></li> <li class="toctree-l3"><a class="reference internal" href="../training_techniques/transfer_learning.html">Transfer learning</a></li> <li class="toctree-l3"><a class="reference internal" href="../training_techniques/use_checkpointing.html">Save and load checkpoints</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../parallel_training/index.html">Parallel training</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../parallel_training/ensembling.html">Ensembling on multiple devices</a></li> <li class="toctree-l3"><a class="reference internal" href="../parallel_training/flax_on_pjit.html">Scale up Flax Modules on multiple devices</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../model_inspection/index.html">Model inspection</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../model_inspection/model_surgery.html">Model surgery</a></li> <li class="toctree-l3"><a class="reference internal" href="../model_inspection/extracting_intermediates.html">Extracting intermediate values</a></li> </ul> </details></li> <li class="toctree-l2 current active has-children"><a class="reference internal" href="index.html">Converting and upgrading</a><details open="open"><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul 
class="current"> <li class="toctree-l3 current active"><a class="current reference internal" href="#">Migrating from Haiku to Flax</a></li> <li class="toctree-l3"><a class="reference internal" href="convert_pytorch_to_flax.html">Convert PyTorch models to Flax</a></li> <li class="toctree-l3"><a class="reference internal" href="orbax_upgrade_guide.html">Migrate checkpointing to Orbax</a></li> <li class="toctree-l3"><a class="reference internal" href="optax_update_guide.html">Upgrading my codebase to Optax</a></li> <li class="toctree-l3"><a class="reference internal" href="linen_upgrade_guide.html">Upgrading my codebase to Linen</a></li> <li class="toctree-l3"><a class="reference internal" href="rnncell_upgrade_guide.html">RNNCellBase Upgrade Guide</a></li> <li class="toctree-l3"><a class="reference internal" href="regular_dict_upgrade_guide.html">Migrate to regular dicts</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../quantization/index.html">Quantization</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../quantization/fp8_basics.html">User Guide on Using FP8</a></li> </ul> </details></li> <li class="toctree-l2"><a class="reference internal" href="../flax_sharp_bits.html">The Sharp Bits</a></li> </ul> </details></li> <li class="toctree-l1 has-children"><a class="reference internal" href="../../examples/index.html">Examples</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l2"><a class="reference internal" href="../../examples/core_examples.html">Core examples</a></li> <li class="toctree-l2"><a class="reference internal" href="../../examples/google_research_examples.html">Google Research examples</a></li> <li class="toctree-l2"><a class="reference internal" 
href="../../examples/repositories_that_use_flax.html">Repositories that use Flax</a></li> <li class="toctree-l2"><a class="reference internal" href="../../examples/community_examples.html">Community examples</a></li> </ul> </details></li> <li class="toctree-l1"><a class="reference internal" href="../../glossary.html">Glossary</a></li> <li class="toctree-l1"><a class="reference internal" href="../../faq.html">Frequently Asked Questions (FAQ)</a></li> <li class="toctree-l1 has-children"><a class="reference internal" href="../../developer_notes/index.html">Developer notes</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l2"><a class="reference internal" href="../../developer_notes/module_lifecycle.html">The Flax Module lifecycle</a></li> <li class="toctree-l2"><a class="reference internal" href="../../developer_notes/lift.html">Lifted transformations</a></li> <li class="toctree-l2"><a class="reference external" href="https://github.com/google/flax/tree/main/docs/flip">FLIPs</a></li> </ul> </details></li> <li class="toctree-l1"><a class="reference internal" href="../../philosophy.html">The Flax philosophy</a></li> <li class="toctree-l1"><a class="reference internal" href="../../contributing.html">How to contribute</a></li> <li class="toctree-l1 has-children"><a class="reference internal" href="../../api_reference/index.html">API Reference</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.config.html">flax.config package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.core.frozen_dict.html">flax.core.frozen_dict package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.cursor.html">flax.cursor package</a></li> <li 
class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.errors.html">flax.errors package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.jax_utils.html">flax.jax_utils package</a></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../../api_reference/flax.linen/index.html">flax.linen</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/module.html">Module</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/init_apply.html">Init/Apply</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/layers.html">Layers</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/activation_functions.html">Activation functions</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/initializers.html">Initializers</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/transformations.html">Transformations</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/inspection.html">Inspection</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/variable.html">Variable dictionary</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/spmd.html">SPMD</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/decorators.html">Decorators</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/profiling.html">Profiling</a></li> </ul> </details></li> <li class="toctree-l2"><a class="reference internal" 
href="../../api_reference/flax.serialization.html">flax.serialization package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.struct.html">flax.struct package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.traceback_util.html">flax.traceback_util package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.training.html">flax.training package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.traverse_util.html">flax.traverse_util package</a></li> </ul> </details></li> <li class="toctree-l1"><a class="reference external" href="https://flax.readthedocs.io/en/latest/index.html">Flax NNX</a></li> </ul> </div> </nav></div> </div> <div class="sidebar-primary-items__end sidebar-primary__section"> </div> <div id="rtd-footer-container"></div> </div> <main id="main-content" class="bd-main" role="main"> <div class="sbt-scroll-pixel-helper"></div> <div class="bd-content"> <div class="bd-article-container"> <div class="bd-header-article d-print-none"> <div class="header-article-items header-article__inner"> <div class="header-article-items__start"> <div class="header-article-item"><button class="sidebar-toggle primary-toggle btn btn-sm" title="Toggle primary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip"> <span class="fa-solid fa-bars"></span> </button></div> </div> <div class="header-article-items__end"> <div class="header-article-item"> <div class="article-header-buttons"> <a href="https://github.com/google/flax" target="_blank" class="btn btn-sm btn-source-repository-button" title="Source repository" data-bs-placement="bottom" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fab fa-github"></i> </span> </a> <div class="dropdown dropdown-download-buttons"> <button class="btn dropdown-toggle" type="button" data-bs-toggle="dropdown" aria-expanded="false" 
aria-label="Download this page"> <i class="fas fa-download"></i> </button> <ul class="dropdown-menu"> <li><a href="../../_sources/guides/converting_and_upgrading/haiku_migration_guide.rst" target="_blank" class="btn btn-sm btn-download-source-button dropdown-item" title="Download source file" data-bs-placement="left" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-file"></i> </span> <span class="btn__text-container">.rst</span> </a> </li> <li> <button onclick="window.print()" class="btn btn-sm btn-download-pdf-button dropdown-item" title="Print to PDF" data-bs-placement="left" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-file-pdf"></i> </span> <span class="btn__text-container">.pdf</span> </button> </li> </ul> </div> <button onclick="toggleFullScreen()" class="btn btn-sm btn-fullscreen-button" title="Fullscreen mode" data-bs-placement="bottom" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-expand"></i> </span> </button> <script> document.write(` <button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button" title="light/dark" aria-label="light/dark" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light"></i> <i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark"></i> <i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto"></i> </button> `); </script> <script> document.write(` <button class="btn btn-sm pst-navbar-icon search-button search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="fa-solid fa-magnifying-glass fa-lg"></i> </button> `); </script> <button class="sidebar-toggle secondary-toggle btn btn-sm" title="Toggle secondary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip"> <span class="fa-solid fa-list"></span> </button> </div></div> </div> </div> </div> <div id="jb-print-docs-body" 
class="onlyprint"> <h1>Migrating from Haiku to Flax</h1> <!-- Table of contents --> <div id="print-main-content"> <div id="jb-print-toc"> <div> <h2> Contents </h2> </div> <nav aria-label="Page"> <ul class="visible nav section-nav flex-column"> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#basic-example">Basic Example</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#handling-state">Handling State</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#using-multiple-methods">Using Multiple Methods</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#lifted-transforms">Lifted Transforms</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#scan-over-layers">Scan over layers</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#top-level-haiku-functions-vs-top-level-flax-modules">Top-level Haiku functions vs top-level Flax modules</a></li> </ul> </nav> </div> </div> </div> <div id="searchbox"></div> <article class="bd-article"> <div class="section" id="migrating-from-haiku-to-flax"> <h1>Migrating from Haiku to Flax<a class="headerlink" href="#migrating-from-haiku-to-flax" title="Permalink to this heading">#</a></h1> <p>This guide will walk through the process of migrating Haiku models to Flax, and highlight the differences between the two libraries.</p> <div class="section" id="basic-example"> <h2>Basic Example<a class="headerlink" href="#basic-example" title="Permalink to this heading">#</a></h2> <p>To create custom Modules you subclass from a <code class="docutils literal notranslate"><span class="pre">Module</span></code> base class in both Haiku and Flax. 
However, Haiku classes use a regular <code class="docutils literal notranslate"><span class="pre">__init__</span></code> method whereas Flax classes are <code class="docutils literal notranslate"><span class="pre">dataclasses</span></code>, meaning you define some class attributes that are used to automatically generate a constructor. Also, all Flax Modules accept a <code class="docutils literal notranslate"><span class="pre">name</span></code> argument without needing to define it, whereas in Haiku <code class="docutils literal notranslate"><span class="pre">name</span></code> must be explicitly defined in the constructor signature and passed to the superclass constructor.</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-0" name="sd-tab-set-0" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Haiku" for="sd-tab-item-0"> Haiku</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">haiku</span> <span class="k">as</span> <span class="nn">hk</span>
<span class="kn">import</span> <span class="nn">jax</span>

<span class="k">class</span> <span class="nc">Block</span><span class="p">(</span><span class="n">hk</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
  <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">features</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">)</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">features</span> <span class="o">=</span> <span class="n">features</span>

  <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">hk</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">features</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">hk</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">hk</span><span class="o">.</span><span class="n">next_rng_key</span><span class="p">(),</span> <span class="mf">0.5</span> <span class="k">if</span> <span class="n">training</span> <span class="k">else</span> <span class="mi">0</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">x</span>

<span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">hk</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
  <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dmid</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">dout</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
    <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">)</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">dmid</span> <span class="o">=</span> <span class="n">dmid</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">dout</span> <span class="o">=</span> <span class="n">dout</span>

  <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">Block</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dmid</span><span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">hk</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dout</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">x</span>
</pre></div> </div> </div> <input id="sd-tab-item-1" name="sd-tab-set-0" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Flax" for="sd-tab-item-1"> Flax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">flax.linen</span> <span class="k">as</span> <span class="nn">nn</span>
<span class="kn">import</span> <span class="nn">jax</span>

<span class="k">class</span> <span class="nc">Block</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
  <span class="n">features</span><span class="p">:</span> <span class="nb">int</span>

  <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span>
  <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">features</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">deterministic</span><span class="o">=</span><span class="ow">not</span> <span class="n">training</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">x</span>

<span class="k">class</span> <span class="nc">Model</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span>
  <span class="n">dmid</span><span class="p">:</span> <span class="nb">int</span>
  <span class="n">dout</span><span class="p">:</span> <span class="nb">int</span>

  <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span>
  <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">Block</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dmid</span><span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dout</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">x</span>
</pre></div> </div> </div> </div> </p> <p>The <code class="docutils literal notranslate"><span class="pre">__call__</span></code> method looks very similar in both libraries. In Flax, however, you must apply the <code class="docutils literal notranslate"><span class="pre">&#64;nn.compact</span></code> decorator to define submodules inline; in Haiku, this is the default behavior.</p> <p>Haiku and Flax differ substantially in how you construct the model. In Haiku, you use <code class="docutils literal notranslate"><span class="pre">hk.transform</span></code> over a function that calls your Module; <code class="docutils literal notranslate"><span class="pre">transform</span></code> returns an object with <code class="docutils literal notranslate"><span class="pre">init</span></code> and <code class="docutils literal notranslate"><span class="pre">apply</span></code> methods. 
In Flax, you simply instantiate your Module.</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-2" name="sd-tab-set-1" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Haiku" for="sd-tab-item-2"> Haiku</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span> <span class="k">return</span> <span class="n">Model</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">10</span><span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">hk</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">forward</span><span class="p">)</span> </pre></div> </div> </div> <input id="sd-tab-item-3" name="sd-tab-set-1" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Flax" for="sd-tab-item-3"> Flax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="o">...</span> <span class="n">model</span> <span class="o">=</span> <span class="n">Model</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span> </pre></div> </div> </div> </div> </p> <p>To get the model parameters in both libraries you use the <code class="docutils literal notranslate"><span class="pre">init</span></code> method with a <code class="docutils literal notranslate"><span class="pre">random.key</span></code> plus some inputs to run the model. 
The main difference here is that Flax returns a mapping from collection names to nested array dictionaries; <code class="docutils literal notranslate"><span class="pre">params</span></code> is just one of these possible collections. In Haiku, you get the <code class="docutils literal notranslate"><span class="pre">params</span></code> structure directly.</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-4" name="sd-tab-set-2" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Haiku" for="sd-tab-item-4"> Haiku</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">sample_x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">numpy</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">784</span><span class="p">))</span> <span class="n">params</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span> <span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">sample_x</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">False</span> <span class="c1"># &lt;== inputs</span> <span class="p">)</span> <span class="o">...</span> </pre></div> </div> </div> <input id="sd-tab-item-5" name="sd-tab-set-2" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Flax" for="sd-tab-item-5"> Flax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">sample_x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span 
class="n">numpy</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">784</span><span class="p">))</span> <span class="n">variables</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span> <span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">sample_x</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">False</span> <span class="c1"># &lt;== inputs</span> <span class="p">)</span> <span class="n">params</span> <span class="o">=</span> <span class="n">variables</span><span class="p">[</span><span class="s2">&quot;params&quot;</span><span class="p">]</span> </pre></div> </div> </div> </div> </p> <p>One very important thing to note is that in Flax the parameters structure is hierarchical, with one level per nested module and a final level for the parameter name. In Haiku, the parameters structure is a Python dictionary with a two-level hierarchy: the fully qualified module name mapping to the parameter name. 
The module name consists of a <code class="docutils literal notranslate"><span class="pre">/</span></code> separated string path of all the nested Modules.</p> <div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-6" name="sd-tab-set-3" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Haiku" for="sd-tab-item-6"> Haiku</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="o">...</span> <span class="p">{</span> <span class="s1">&#39;model/block/linear&#39;</span><span class="p">:</span> <span class="p">{</span> <span class="s1">&#39;b&#39;</span><span class="p">:</span> <span class="p">(</span><span class="mi">256</span><span class="p">,),</span> <span class="s1">&#39;w&#39;</span><span class="p">:</span> <span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">256</span><span class="p">),</span> <span class="p">},</span> <span class="s1">&#39;model/linear&#39;</span><span class="p">:</span> <span class="p">{</span> <span class="s1">&#39;b&#39;</span><span class="p">:</span> <span class="p">(</span><span class="mi">10</span><span class="p">,),</span> <span class="s1">&#39;w&#39;</span><span class="p">:</span> <span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">10</span><span class="p">),</span> <span class="p">}</span> <span class="p">}</span> <span class="o">...</span> </pre></div> </div> </div> <input id="sd-tab-item-7" name="sd-tab-set-3" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Flax" for="sd-tab-item-7"> Flax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">FrozenDict</span><span class="p">({</span> <span class="n">Block_0</span><span class="p">:</span> <span class="p">{</span> <span class="n">Dense_0</span><span 
class="p">:</span> <span class="p">{</span> <span class="n">bias</span><span class="p">:</span> <span class="p">(</span><span class="mi">256</span><span class="p">,),</span> <span class="n">kernel</span><span class="p">:</span> <span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">256</span><span class="p">),</span> <span class="p">},</span> <span class="p">},</span> <span class="n">Dense_0</span><span class="p">:</span> <span class="p">{</span> <span class="n">bias</span><span class="p">:</span> <span class="p">(</span><span class="mi">10</span><span class="p">,),</span> <span class="n">kernel</span><span class="p">:</span> <span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">10</span><span class="p">),</span> <span class="p">},</span> <span class="p">})</span> </pre></div> </div> </div> </div> <p>During training in both frameworks you pass the parameters structure to the <code class="docutils literal notranslate"><span class="pre">apply</span></code> method to run the forward pass. 
Since we are using dropout, in both cases we must provide a <code class="docutils literal notranslate"><span class="pre">key</span></code> to <code class="docutils literal notranslate"><span class="pre">apply</span></code> in order to generate the random dropout masks.</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-8" name="sd-tab-set-4" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Haiku" for="sd-tab-item-8"> Haiku</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span> <span class="k">def</span> <span class="nf">loss_fn</span><span class="p">(</span><span class="n">params</span><span class="p">):</span> <span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span> <span class="n">params</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span> <span class="c1"># &lt;== inputs</span> <span class="p">)</span> <span class="k">return</span> <span class="n">optax</span><span class="o">.</span><span class="n">softmax_cross_entropy_with_integer_labels</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span 
class="n">grad</span><span class="p">(</span><span class="n">loss_fn</span><span class="p">)(</span><span class="n">params</span><span class="p">)</span> <span class="n">params</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">p</span><span class="p">,</span> <span class="n">g</span><span class="p">:</span> <span class="n">p</span> <span class="o">-</span> <span class="mf">0.1</span> <span class="o">*</span> <span class="n">g</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">grads</span><span class="p">)</span> <span class="k">return</span> <span class="n">params</span> </pre></div> </div> </div> <input id="sd-tab-item-9" name="sd-tab-set-4" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Flax" for="sd-tab-item-9"> Flax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span> <span class="k">def</span> <span class="nf">loss_fn</span><span class="p">(</span><span class="n">params</span><span class="p">):</span> <span class="n">logits</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span> <span class="p">{</span><span class="s1">&#39;params&#39;</span><span class="p">:</span> <span class="n">params</span><span class="p">},</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span 
class="kc">True</span><span class="p">,</span> <span class="c1"># &lt;== inputs</span> <span class="n">rngs</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;dropout&#39;</span><span class="p">:</span> <span class="n">key</span><span class="p">}</span> <span class="p">)</span> <span class="k">return</span> <span class="n">optax</span><span class="o">.</span><span class="n">softmax_cross_entropy_with_integer_labels</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">grad</span><span class="p">(</span><span class="n">loss_fn</span><span class="p">)(</span><span class="n">params</span><span class="p">)</span> <span class="n">params</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">p</span><span class="p">,</span> <span class="n">g</span><span class="p">:</span> <span class="n">p</span> <span class="o">-</span> <span class="mf">0.1</span> <span class="o">*</span> <span class="n">g</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">grads</span><span class="p">)</span> <span class="k">return</span> <span class="n">params</span> </pre></div> </div> </div> </div> </p> <p>The most notable difference is that in Flax you have to pass the parameters inside a dictionary with a <code class="docutils literal notranslate"><span class="pre">params</span></code> key, and the PRNG key inside a dictionary with a <code class="docutils literal notranslate"><span class="pre">dropout</span></code> key. 
This is because in Flax you can have many types of model state and random state. In Haiku, you just pass the parameters and the key directly.</p> </div> <div class="section" id="handling-state"> <h2>Handling State<a class="headerlink" href="#handling-state" title="Permalink to this heading">#</a></h2> <p>Now let’s see how mutable state is handled in both libraries. We will take the same model as before, but now we will replace Dropout with BatchNorm.</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-10" name="sd-tab-set-5" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Haiku" for="sd-tab-item-10"> Haiku</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">Block</span><span class="p">(</span><span class="n">hk</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">features</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">features</span> <span class="o">=</span> <span class="n">features</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="p">:</span> <span class="nb">bool</span><span 
class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">hk</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">features</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">hk</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">(</span> <span class="n">create_scale</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">create_offset</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">decay_rate</span><span class="o">=</span><span class="mf">0.99</span> <span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span class="n">is_training</span><span class="o">=</span><span class="n">training</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> </pre></div> </div> </div> <input id="sd-tab-item-11" name="sd-tab-set-5" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Flax" for="sd-tab-item-11"> Flax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">Block</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="n">features</span><span class="p">:</span> <span class="nb">int</span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="k">def</span> <span 
class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">features</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">(</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.99</span> <span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span class="n">use_running_average</span><span class="o">=</span><span class="ow">not</span> <span class="n">training</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> </pre></div> </div> </div> </div> </p> <p>The code is very similar in this case as both libraries provide a BatchNorm layer. 
The most notable difference is that Haiku uses <code class="docutils literal notranslate"><span class="pre">is_training</span></code> to control whether or not to update the running statistics, whereas Flax uses <code class="docutils literal notranslate"><span class="pre">use_running_average</span></code> for the same purpose.</p> <p>To instantiate a stateful model in Haiku you use <code class="docutils literal notranslate"><span class="pre">hk.transform_with_state</span></code>, which changes the signature for <code class="docutils literal notranslate"><span class="pre">init</span></code> and <code class="docutils literal notranslate"><span class="pre">apply</span></code> to accept and return state. As before, in Flax you construct the Module directly.</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-12" name="sd-tab-set-6" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Haiku" for="sd-tab-item-12"> Haiku</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span> <span class="k">return</span> <span class="n">Model</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">10</span><span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">hk</span><span class="o">.</span><span class="n">transform_with_state</span><span class="p">(</span><span class="n">forward</span><span class="p">)</span> </pre></div> </div> </div> <input id="sd-tab-item-13" name="sd-tab-set-6" type="radio"> <label class="sd-tab-label" 
data-sync-group="tab" data-sync-id="Flax" for="sd-tab-item-13"> Flax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="o">...</span> <span class="n">model</span> <span class="o">=</span> <span class="n">Model</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">10</span><span class="p">)</span> </pre></div> </div> </div> </div> </p> <p>To initialize both the parameters and state you just call the <code class="docutils literal notranslate"><span class="pre">init</span></code> method as before. However, in Haiku you now get <code class="docutils literal notranslate"><span class="pre">state</span></code> as a second return value, and in Flax you get a new <code class="docutils literal notranslate"><span class="pre">batch_stats</span></code> collection in the <code class="docutils literal notranslate"><span class="pre">variables</span></code> dictionary. Note that since <code class="docutils literal notranslate"><span class="pre">hk.BatchNorm</span></code> only initializes batch statistics when <code class="docutils literal notranslate"><span class="pre">is_training=True</span></code>, we must set <code class="docutils literal notranslate"><span class="pre">training=True</span></code> when initializing parameters of a Haiku model with an <code class="docutils literal notranslate"><span class="pre">hk.BatchNorm</span></code> layer. 
In Flax, we can set <code class="docutils literal notranslate"><span class="pre">training=False</span></code> as usual.</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-14" name="sd-tab-set-7" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Haiku" for="sd-tab-item-14"> Haiku</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">sample_x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">numpy</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">784</span><span class="p">))</span> <span class="n">params</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span> <span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="hll"> <span class="n">sample_x</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span> <span class="c1"># &lt;== inputs</span> </span><span class="p">)</span> <span class="o">...</span> </pre></div> </div> </div> <input id="sd-tab-item-15" name="sd-tab-set-7" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Flax" for="sd-tab-item-15"> Flax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">sample_x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">numpy</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span 
class="mi">784</span><span class="p">))</span> <span class="n">variables</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span> <span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="hll"> <span class="n">sample_x</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">False</span> <span class="c1"># &lt;== inputs</span> </span><span class="p">)</span> <span class="n">params</span><span class="p">,</span> <span class="n">batch_stats</span> <span class="o">=</span> <span class="n">variables</span><span class="p">[</span><span class="s2">&quot;params&quot;</span><span class="p">],</span> <span class="n">variables</span><span class="p">[</span><span class="s2">&quot;batch_stats&quot;</span><span class="p">]</span> </pre></div> </div> </div> </div> </p> <p>In general, in Flax you might find other state collections in the <code class="docutils literal notranslate"><span class="pre">variables</span></code> dictionary, such as <code class="docutils literal notranslate"><span class="pre">cache</span></code> for autoregressive transformer models, <code class="docutils literal notranslate"><span class="pre">intermediates</span></code> for intermediate values added using <code class="docutils literal notranslate"><span class="pre">Module.sow</span></code>, or other collection names defined by custom layers. 
Haiku only makes a distinction between <code class="docutils literal notranslate"><span class="pre">params</span></code> (variables which do not change while running <code class="docutils literal notranslate"><span class="pre">apply</span></code>) and <code class="docutils literal notranslate"><span class="pre">state</span></code> (variables which can change while running <code class="docutils literal notranslate"><span class="pre">apply</span></code>).</p> <p>Now, training looks very similar in both frameworks as you use the same <code class="docutils literal notranslate"><span class="pre">apply</span></code> method to run the forward pass. In Haiku, you now pass the <code class="docutils literal notranslate"><span class="pre">state</span></code> as the second argument to <code class="docutils literal notranslate"><span class="pre">apply</span></code>, and get the new state as the second return value. In Flax, you instead add <code class="docutils literal notranslate"><span class="pre">batch_stats</span></code> as a new key to the input dictionary, and get the <code class="docutils literal notranslate"><span class="pre">updates</span></code> variables dictionary as the second return value.</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-16" name="sd-tab-set-8" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Haiku" for="sd-tab-item-16"> Haiku</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">state</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span> <span class="k">def</span> <span class="nf">loss_fn</span><span class="p">(</span><span class="n">params</span><span class="p">):</span> <span 
class="n">logits</span><span class="p">,</span> <span class="n">new_state</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span> <span class="n">params</span><span class="p">,</span> <span class="n">state</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># &lt;== rng</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span> <span class="c1"># &lt;== inputs</span> <span class="p">)</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">optax</span><span class="o">.</span><span class="n">softmax_cross_entropy_with_integer_labels</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span> <span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">new_state</span> <span class="n">grads</span><span class="p">,</span> <span class="n">new_state</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">grad</span><span class="p">(</span><span class="n">loss_fn</span><span class="p">,</span> <span class="n">has_aux</span><span class="o">=</span><span class="kc">True</span><span class="p">)(</span><span class="n">params</span><span class="p">)</span> <span class="n">params</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">p</span><span class="p">,</span> <span class="n">g</span><span class="p">:</span> <span class="n">p</span> <span class="o">-</span> <span class="mf">0.1</span> <span class="o">*</span> <span 
class="n">g</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">grads</span><span class="p">)</span> <span class="k">return</span> <span class="n">params</span><span class="p">,</span> <span class="n">new_state</span> </pre></div> </div> </div> <input id="sd-tab-item-17" name="sd-tab-set-8" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Flax" for="sd-tab-item-17"> Flax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">batch_stats</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span> <span class="k">def</span> <span class="nf">loss_fn</span><span class="p">(</span><span class="n">params</span><span class="p">):</span> <span class="n">logits</span><span class="p">,</span> <span class="n">updates</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span> <span class="p">{</span><span class="s1">&#39;params&#39;</span><span class="p">:</span> <span class="n">params</span><span class="p">,</span> <span class="s1">&#39;batch_stats&#39;</span><span class="p">:</span> <span class="n">batch_stats</span><span class="p">},</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="c1"># &lt;== inputs</span> <span class="n">mutable</span><span class="o">=</span><span class="s1">&#39;batch_stats&#39;</span><span class="p">,</span> <span class="p">)</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">optax</span><span class="o">.</span><span 
class="n">softmax_cross_entropy_with_integer_labels</span><span class="p">(</span><span class="n">logits</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span> <span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">updates</span><span class="p">[</span><span class="s2">&quot;batch_stats&quot;</span><span class="p">]</span> <span class="n">grads</span><span class="p">,</span> <span class="n">batch_stats</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">grad</span><span class="p">(</span><span class="n">loss_fn</span><span class="p">,</span> <span class="n">has_aux</span><span class="o">=</span><span class="kc">True</span><span class="p">)(</span><span class="n">params</span><span class="p">)</span> <span class="n">params</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">p</span><span class="p">,</span> <span class="n">g</span><span class="p">:</span> <span class="n">p</span> <span class="o">-</span> <span class="mf">0.1</span> <span class="o">*</span> <span class="n">g</span><span class="p">,</span> <span class="n">params</span><span class="p">,</span> <span class="n">grads</span><span class="p">)</span> <span class="k">return</span> <span class="n">params</span><span class="p">,</span> <span class="n">batch_stats</span> </pre></div> </div> </div> </div> </p> <p>One major difference is that in Flax a state collection can be mutable or immutable. 
During <code class="docutils literal notranslate"><span class="pre">init</span></code> all collections are mutable by default; during <code class="docutils literal notranslate"><span class="pre">apply</span></code>, however, you have to explicitly specify which collections are mutable. In this example, we specify that <code class="docutils literal notranslate"><span class="pre">batch_stats</span></code> is mutable. Here a single string is passed, but a list can also be given if more than one collection is mutable. If this is not done, an error will be raised at runtime when trying to mutate <code class="docutils literal notranslate"><span class="pre">batch_stats</span></code>. Also, when <code class="docutils literal notranslate"><span class="pre">mutable</span></code> is anything other than <code class="docutils literal notranslate"><span class="pre">False</span></code>, the <code class="docutils literal notranslate"><span class="pre">updates</span></code> dictionary is returned as the second return value of <code class="docutils literal notranslate"><span class="pre">apply</span></code>; otherwise only the model output is returned. Haiku makes the mutable/immutable distinction by having <code class="docutils literal notranslate"><span class="pre">params</span></code> (immutable) and <code class="docutils literal notranslate"><span class="pre">state</span></code> (mutable), and by using either <code class="docutils literal notranslate"><span class="pre">hk.transform</span></code> or <code class="docutils literal notranslate"><span class="pre">hk.transform_with_state</span></code>.</p> </div> <div class="section" id="using-multiple-methods"> <h2>Using Multiple Methods<a class="headerlink" href="#using-multiple-methods" title="Permalink to this heading">#</a></h2> <p>In this section we will take a look at how to use multiple methods in Haiku and Flax. 
As an example, we will implement an auto-encoder model with three methods: <code class="docutils literal notranslate"><span class="pre">encode</span></code>, <code class="docutils literal notranslate"><span class="pre">decode</span></code>, and <code class="docutils literal notranslate"><span class="pre">__call__</span></code>.</p> <p>In Haiku, we can just define the submodules that <code class="docutils literal notranslate"><span class="pre">encode</span></code> and <code class="docutils literal notranslate"><span class="pre">decode</span></code> need directly in <code class="docutils literal notranslate"><span class="pre">__init__</span></code>; in this case each will just use a <code class="docutils literal notranslate"><span class="pre">Linear</span></code> layer. In Flax, we will define an <code class="docutils literal notranslate"><span class="pre">encoder</span></code> and a <code class="docutils literal notranslate"><span class="pre">decoder</span></code> Module ahead of time in <code class="docutils literal notranslate"><span class="pre">setup</span></code>, and use them in the <code class="docutils literal notranslate"><span class="pre">encode</span></code> and <code class="docutils literal notranslate"><span class="pre">decode</span></code> methods, respectively.</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-18" name="sd-tab-set-9" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Haiku" for="sd-tab-item-18"> Haiku</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">AutoEncoder</span><span class="p">(</span><span class="n">hk</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">embed_dim</span><span 
class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">output_dim</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">hk</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;encoder&quot;</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">hk</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">output_dim</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;decoder&quot;</span><span class="p">)</span> <span class="k">def</span> <span class="nf">encode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span 
class="n">x</span><span class="p">)</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> </pre></div> </div> </div> <input id="sd-tab-item-19" name="sd-tab-set-9" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Flax" for="sd-tab-item-19"> Flax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">AutoEncoder</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="n">embed_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="n">output_dim</span><span class="p">:</span> <span class="nb">int</span> <span class="k">def</span> <span class="nf">setup</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span 
class="n">Dense</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">output_dim</span><span class="p">)</span> <span class="k">def</span> <span class="nf">encode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">def</span> <span class="nf">decode</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encode</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> </pre></div> </div> </div> </div> </p> <p>Note that in Flax <code class="docutils literal notranslate"><span class="pre">setup</span></code> doesn’t run after <code class="docutils literal notranslate"><span class="pre">__init__</span></code>; instead, it runs when <code class="docutils literal notranslate"><span class="pre">init</span></code> or <code class="docutils literal notranslate"><span class="pre">apply</span></code> is called.</p> <p>Now, we 
want to be able to call any method from our <code class="docutils literal notranslate"><span class="pre">AutoEncoder</span></code> model. In Haiku we can define multiple <code class="docutils literal notranslate"><span class="pre">apply</span></code> methods for a module through <code class="docutils literal notranslate"><span class="pre">hk.multi_transform</span></code>. The function passed to <code class="docutils literal notranslate"><span class="pre">multi_transform</span></code> defines how to initialize the module and which different apply methods to generate.</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-20" name="sd-tab-set-10" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Haiku" for="sd-tab-item-20"> Haiku</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">forward</span><span class="p">():</span> <span class="n">module</span> <span class="o">=</span> <span class="n">AutoEncoder</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">784</span><span class="p">)</span> <span class="n">init</span> <span class="o">=</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">module</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">init</span><span class="p">,</span> <span class="p">(</span><span class="n">module</span><span class="o">.</span><span class="n">encode</span><span class="p">,</span> <span class="n">module</span><span class="o">.</span><span class="n">decode</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">hk</span><span class="o">.</span><span class="n">multi_transform</span><span class="p">(</span><span class="n">forward</span><span class="p">)</span> 
</pre></div> </div> </div> <input id="sd-tab-item-21" name="sd-tab-set-10" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Flax" for="sd-tab-item-21"> Flax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="o">...</span> <span class="n">model</span> <span class="o">=</span> <span class="n">AutoEncoder</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">784</span><span class="p">)</span> </pre></div> </div> </div> </div> </p> <p>To initialize the parameters of our model, <code class="docutils literal notranslate"><span class="pre">init</span></code> can be used to trigger the <code class="docutils literal notranslate"><span class="pre">__call__</span></code> method, which uses both the <code class="docutils literal notranslate"><span class="pre">encode</span></code> and <code class="docutils literal notranslate"><span class="pre">decode</span></code> methods. 
This will create all the necessary parameters for the model.</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-22" name="sd-tab-set-11" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Haiku" for="sd-tab-item-22"> Haiku</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">params</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span> <span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="o">=</span><span class="n">jax</span><span class="o">.</span><span class="n">numpy</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">784</span><span class="p">)),</span> <span class="p">)</span> <span class="o">...</span> </pre></div> </div> </div> <input id="sd-tab-item-23" name="sd-tab-set-11" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Flax" for="sd-tab-item-23"> Flax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">variables</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span> <span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="o">=</span><span class="n">jax</span><span class="o">.</span><span class="n">numpy</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">784</span><span 
class="p">)),</span> <span class="p">)</span> <span class="n">params</span> <span class="o">=</span> <span class="n">variables</span><span class="p">[</span><span class="s2">&quot;params&quot;</span><span class="p">]</span> </pre></div> </div> </div> </div> </p> <p>This generates the following parameter structure.</p> <div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-24" name="sd-tab-set-12" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Haiku" for="sd-tab-item-24"> Haiku</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="p">{</span> <span class="s1">&#39;auto_encoder/~/decoder&#39;</span><span class="p">:</span> <span class="p">{</span> <span class="s1">&#39;b&#39;</span><span class="p">:</span> <span class="p">(</span><span class="mi">784</span><span class="p">,),</span> <span class="s1">&#39;w&#39;</span><span class="p">:</span> <span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">784</span><span class="p">)</span> <span class="p">},</span> <span class="s1">&#39;auto_encoder/~/encoder&#39;</span><span class="p">:</span> <span class="p">{</span> <span class="s1">&#39;b&#39;</span><span class="p">:</span> <span class="p">(</span><span class="mi">256</span><span class="p">,),</span> <span class="s1">&#39;w&#39;</span><span class="p">:</span> <span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">256</span><span class="p">)</span> <span class="p">}</span> <span class="p">}</span> </pre></div> </div> </div> <input id="sd-tab-item-25" name="sd-tab-set-12" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Flax" for="sd-tab-item-25"> Flax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">FrozenDict</span><span class="p">({</span> 
<span class="n">decoder</span><span class="p">:</span> <span class="p">{</span> <span class="n">bias</span><span class="p">:</span> <span class="p">(</span><span class="mi">784</span><span class="p">,),</span> <span class="n">kernel</span><span class="p">:</span> <span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="mi">784</span><span class="p">),</span> <span class="p">},</span> <span class="n">encoder</span><span class="p">:</span> <span class="p">{</span> <span class="n">bias</span><span class="p">:</span> <span class="p">(</span><span class="mi">256</span><span class="p">,),</span> <span class="n">kernel</span><span class="p">:</span> <span class="p">(</span><span class="mi">784</span><span class="p">,</span> <span class="mi">256</span><span class="p">),</span> <span class="p">},</span> <span class="p">})</span> </pre></div> </div> </div> </div> <p>Finally, let’s explore how we can employ the <code class="docutils literal notranslate"><span class="pre">apply</span></code> function to invoke the <code class="docutils literal notranslate"><span class="pre">encode</span></code> method:</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-26" name="sd-tab-set-13" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Haiku" for="sd-tab-item-26"> Haiku</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">encode</span><span class="p">,</span> <span class="n">decode</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">apply</span> <span class="n">z</span> <span class="o">=</span> <span class="n">encode</span><span class="p">(</span> <span class="n">params</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="c1"># &lt;== rng</span> <span class="n">x</span><span class="o">=</span><span 
class="n">jax</span><span class="o">.</span><span class="n">numpy</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">784</span><span class="p">)),</span> <span class="p">)</span> </pre></div> </div> </div> <input id="sd-tab-item-27" name="sd-tab-set-13" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Flax" for="sd-tab-item-27"> Flax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="o">...</span> <span class="n">z</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span> <span class="p">{</span><span class="s2">&quot;params&quot;</span><span class="p">:</span> <span class="n">params</span><span class="p">},</span> <span class="n">x</span><span class="o">=</span><span class="n">jax</span><span class="o">.</span><span class="n">numpy</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">784</span><span class="p">)),</span> <span class="n">method</span><span class="o">=</span><span class="s2">&quot;encode&quot;</span><span class="p">,</span> <span class="p">)</span> </pre></div> </div> </div> </div> </p> <p>Because the Haiku <code class="docutils literal notranslate"><span class="pre">apply</span></code> function is generated through <code class="docutils literal notranslate"><span class="pre">hk.multi_transform</span></code>, it’s a tuple of two functions which we can unpack into an <code class="docutils literal notranslate"><span class="pre">encode</span></code> and <code class="docutils literal notranslate"><span class="pre">decode</span></code> function which correspond to the methods on the <code class="docutils literal notranslate"><span class="pre">AutoEncoder</span></code> 
module. In Flax we call the <code class="docutils literal notranslate"><span class="pre">encode</span></code> method by passing the method name as a string. Another noteworthy distinction here is that in Haiku, <code class="docutils literal notranslate"><span class="pre">rng</span></code> needs to be explicitly passed, even though the module does not use any stochastic operations during <code class="docutils literal notranslate"><span class="pre">apply</span></code>. In Flax this is not necessary (check out <a class="reference external" href="https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html">Randomness and PRNGs in Flax</a>). The Haiku <code class="docutils literal notranslate"><span class="pre">rng</span></code> is set to <code class="docutils literal notranslate"><span class="pre">None</span></code> here, but you could also use <code class="docutils literal notranslate"><span class="pre">hk.without_apply_rng</span></code> on the <code class="docutils literal notranslate"><span class="pre">apply</span></code> function to remove the <code class="docutils literal notranslate"><span class="pre">rng</span></code> argument.</p> </div> <div class="section" id="lifted-transforms"> <h2>Lifted Transforms<a class="headerlink" href="#lifted-transforms" title="Permalink to this heading">#</a></h2> <p>Both Flax and Haiku provide a set of transforms, which we will refer to as lifted transforms, that wrap JAX transformations in such a way that they can be used with Modules and sometimes provide additional functionality. In this section we will take a look at how to use the lifted version of <code class="docutils literal notranslate"><span class="pre">scan</span></code> in both Flax and Haiku to implement a simple RNN layer.</p> <p>To begin, we will first define an <code class="docutils literal notranslate"><span class="pre">RNNCell</span></code> module that will contain the logic for a single step of the RNN. 
We will also define an <code class="docutils literal notranslate"><span class="pre">initial_state</span></code> method that will be used to initialize the state (a.k.a. <code class="docutils literal notranslate"><span class="pre">carry</span></code>) of the RNN. As with <code class="docutils literal notranslate"><span class="pre">jax.lax.scan</span></code>, the <code class="docutils literal notranslate"><span class="pre">RNNCell.__call__</span></code> method will be a function that takes the carry and input, and returns the new carry and output. In this case, the carry and the output are the same.</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-28" name="sd-tab-set-14" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Haiku" for="sd-tab-item-28"> Haiku</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">RNNCell</span><span class="p">(</span><span class="n">hk</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">=</span> <span class="n">hidden_size</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span 
class="p">,</span> <span class="n">carry</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">carry</span><span class="p">,</span> <span class="n">x</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">hk</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">x</span> <span class="k">def</span> <span class="nf">initial_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span> <span class="k">return</span> <span class="n">jnp</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">))</span> </pre></div> </div> </div> <input id="sd-tab-item-29" name="sd-tab-set-14" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Flax" for="sd-tab-item-29"> Flax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div 
class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">RNNCell</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="n">hidden_size</span><span class="p">:</span> <span class="nb">int</span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">carry</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">carry</span><span class="p">,</span> <span class="n">x</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">x</span> <span class="k">def</span> <span class="nf">initial_state</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">):</span> <span class="k">return</span> <span class="n">jnp</span><span class="o">.</span><span 
class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">))</span> </pre></div> </div> </div> </div> </p> <p>Next, we will define an <code class="docutils literal notranslate"><span class="pre">RNN</span></code> Module that will contain the logic for the entire RNN. In Haiku, we will first initialize the <code class="docutils literal notranslate"><span class="pre">RNNCell</span></code>, then use it to construct the <code class="docutils literal notranslate"><span class="pre">carry</span></code>, and finally use <code class="docutils literal notranslate"><span class="pre">hk.scan</span></code> to run the <code class="docutils literal notranslate"><span class="pre">RNNCell</span></code> over the input sequence. In Flax it’s done a bit differently: we will use <code class="docutils literal notranslate"><span class="pre">nn.scan</span></code> to define a new temporary type that wraps <code class="docutils literal notranslate"><span class="pre">RNNCell</span></code>. During this process we will also instruct <code class="docutils literal notranslate"><span class="pre">nn.scan</span></code> to broadcast the <code class="docutils literal notranslate"><span class="pre">params</span></code> collection (all steps share the same parameters) and to not split the <code class="docutils literal notranslate"><span class="pre">params</span></code> rng stream (so all steps initialize with the same parameters). Finally, we will specify that we want scan to run over the second axis of the input and to stack the outputs along the second axis as well. 
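</p> <p>The axis handling described above can be sketched with plain <code class="docutils literal notranslate"><span class="pre">jax.lax.scan</span></code>, which iterates over the leading axis of its input. This is only an illustration under stated assumptions, not how Flax implements the lifted transform: <code class="docutils literal notranslate"><span class="pre">step</span></code> is a hypothetical, parameter-free cell that keeps a running sum, and the input layout is assumed to be <code class="docutils literal notranslate"><span class="pre">(batch,</span> <span class="pre">time,</span> <span class="pre">features)</span></code>. Scanning over the second axis and stacking the outputs along the second axis then amounts to swapping the first two axes before and after the scan.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre>
```python
import jax
import jax.numpy as jnp


def step(carry, x_t):
    # Hypothetical parameter-free cell: the carry is a running sum over time,
    # and the per-step output equals the new carry.
    new_carry = carry + x_t
    return new_carry, new_carry


# Assumed layout: (batch, time, features).
x = jnp.arange(6.0).reshape(1, 3, 2)
carry0 = jnp.zeros((x.shape[0], x.shape[2]))

# jax.lax.scan iterates over axis 0, so move the time axis to the front.
final_carry, ys = jax.lax.scan(step, carry0, jnp.swapaxes(x, 0, 1))

# Move the stacked outputs back so that time is the second axis again.
ys = jnp.swapaxes(ys, 0, 1)
```
</pre></div> </div>
<p>In this sketch <code class="docutils literal notranslate"><span class="pre">ys</span></code> has the same <code class="docutils literal notranslate"><span class="pre">(batch,</span> <span class="pre">time,</span> <span class="pre">features)</span></code> layout as the input. Telling a lifted scan which axis to iterate over and which axis to stack along expresses the same choice declaratively, so the explicit axis swaps should not be needed there.</p> <p>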
We will then immediately use this temporary type to create an instance of the lifted <code class="docutils literal notranslate"><span class="pre">RNNCell</span></code>, use it to create the <code class="docutils literal notranslate"><span class="pre">carry</span></code>, and then run the <code class="docutils literal notranslate"><span class="pre">__call__</span></code> method, which will <code class="docutils literal notranslate"><span class="pre">scan</span></code> over the sequence.</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-30" name="sd-tab-set-15" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Haiku" for="sd-tab-item-30"> Haiku</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">RNN</span><span class="p">(</span><span class="n">hk</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hidden_size</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span> <span class="o">=</span> <span class="n">hidden_size</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="n">cell</span> <span class="o">=</span> 
<span class="n">RNNCell</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span> <span class="n">carry</span> <span class="o">=</span> <span class="n">cell</span><span class="o">.</span><span class="n">initial_state</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="n">carry</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">hk</span><span class="o">.</span><span class="n">scan</span><span class="p">(</span><span class="n">cell</span><span class="p">,</span> <span class="n">carry</span><span class="p">,</span> <span class="n">jnp</span><span class="o">.</span><span class="n">swapaxes</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span> <span class="n">y</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">swapaxes</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="k">return</span> <span class="n">y</span> </pre></div> </div> </div> <input id="sd-tab-item-31" name="sd-tab-set-15" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Flax" for="sd-tab-item-31"> Flax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">RNN</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="n">hidden_size</span><span class="p">:</span> <span class="nb">int</span> <span 
class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="n">rnn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">scan</span><span class="p">(</span><span class="n">RNNCell</span><span class="p">,</span> <span class="n">variable_broadcast</span><span class="o">=</span><span class="s1">&#39;params&#39;</span><span class="p">,</span> <span class="n">split_rngs</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;params&#39;</span><span class="p">:</span> <span class="kc">False</span><span class="p">},</span> <span class="n">in_axes</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">out_axes</span><span class="o">=</span><span class="mi">1</span><span class="p">)(</span><span class="bp">self</span><span class="o">.</span><span class="n">hidden_size</span><span class="p">)</span> <span class="n">carry</span> <span class="o">=</span> <span class="n">rnn</span><span class="o">.</span><span class="n">initial_state</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="n">carry</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">rnn</span><span class="p">(</span><span class="n">carry</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">y</span> </pre></div> </div> </div> </div> </p> <p>In general, the main difference between lifted transforms in Flax and Haiku is that in Haiku the lifted transforms don’t operate over the state; that is, Haiku will handle the <code class="docutils literal 
notranslate"><span class="pre">params</span></code> and <code class="docutils literal notranslate"><span class="pre">state</span></code> in such a way that they keep the same shape inside and outside of the transform. In Flax, the lifted transforms can operate over both variable collections and rng streams, so the user must define how different collections are treated by each transform according to the transform’s semantics.</p> <p>Finally, let’s quickly view how the <code class="docutils literal notranslate"><span class="pre">RNN</span></code> Module would be used in both Haiku and Flax.</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-32" name="sd-tab-set-16" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Haiku" for="sd-tab-item-32"> Haiku</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">x</span><span class="p">):</span> <span class="k">return</span> <span class="n">RNN</span><span class="p">(</span><span class="mi">64</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">hk</span><span class="o">.</span><span class="n">without_apply_rng</span><span class="p">(</span><span class="n">hk</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">forward</span><span class="p">))</span> <span class="n">params</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span> <span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="o">=</span><span class="n">jax</span><span class="o">.</span><span 
class="n">numpy</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">3</span><span class="p">,</span> <span class="mi">12</span><span class="p">,</span> <span class="mi">32</span><span class="p">)),</span> <span class="p">)</span> <span class="n">y</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span> <span class="n">params</span><span class="p">,</span> <span class="n">x</span><span class="o">=</span><span class="n">jax</span><span class="o">.</span><span class="n">numpy</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">3</span><span class="p">,</span> <span class="mi">12</span><span class="p">,</span> <span class="mi">32</span><span class="p">)),</span> <span class="p">)</span> </pre></div> </div> </div> <input id="sd-tab-item-33" name="sd-tab-set-16" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Flax" for="sd-tab-item-33"> Flax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="o">...</span> <span class="n">model</span> <span class="o">=</span> <span class="n">RNN</span><span class="p">(</span><span class="mi">64</span><span class="p">)</span> <span class="n">variables</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span> <span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">x</span><span class="o">=</span><span class="n">jax</span><span class="o">.</span><span class="n">numpy</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">3</span><span class="p">,</span> <span class="mi">12</span><span class="p">,</span> <span 
class="mi">32</span><span class="p">)),</span> <span class="p">)</span> <span class="n">params</span> <span class="o">=</span> <span class="n">variables</span><span class="p">[</span><span class="s1">&#39;params&#39;</span><span class="p">]</span> <span class="n">y</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span> <span class="p">{</span><span class="s1">&#39;params&#39;</span><span class="p">:</span> <span class="n">params</span><span class="p">},</span> <span class="n">x</span><span class="o">=</span><span class="n">jax</span><span class="o">.</span><span class="n">numpy</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">3</span><span class="p">,</span> <span class="mi">12</span><span class="p">,</span> <span class="mi">32</span><span class="p">)),</span> <span class="p">)</span> </pre></div> </div> </div> </div> </p> <p>The only notable change with respect to the examples in the previous sections is that this time around we used <code class="docutils literal notranslate"><span class="pre">hk.without_apply_rng</span></code> in Haiku so we didn’t have to pass the <code class="docutils literal notranslate"><span class="pre">rng</span></code> argument as <code class="docutils literal notranslate"><span class="pre">None</span></code> to the <code class="docutils literal notranslate"><span class="pre">apply</span></code> method.</p> </div> <div class="section" id="scan-over-layers"> <h2>Scan over layers<a class="headerlink" href="#scan-over-layers" title="Permalink to this heading">#</a></h2> <p>One very important application of <code class="docutils literal notranslate"><span class="pre">scan</span></code> is to apply a sequence of layers iteratively over an input, passing the output of each layer as the input to the next layer. This is very useful for reducing compilation time for big models, since the layer body is traced and compiled only once. 
As an example, we will create a simple <code class="docutils literal notranslate"><span class="pre">Block</span></code> Module, and then use it inside an <code class="docutils literal notranslate"><span class="pre">MLP</span></code> Module that will apply the <code class="docutils literal notranslate"><span class="pre">Block</span></code> Module <code class="docutils literal notranslate"><span class="pre">num_layers</span></code> times.</p> <p>In Haiku, we define the <code class="docutils literal notranslate"><span class="pre">Block</span></code> Module as usual, and then inside <code class="docutils literal notranslate"><span class="pre">MLP</span></code> we will use <code class="docutils literal notranslate"><span class="pre">hk.experimental.layer_stack</span></code> over a <code class="docutils literal notranslate"><span class="pre">stack_block</span></code> function to create a stack of <code class="docutils literal notranslate"><span class="pre">Block</span></code> Modules. In Flax, the definition of <code class="docutils literal notranslate"><span class="pre">Block</span></code> is a little different: <code class="docutils literal notranslate"><span class="pre">__call__</span></code> will accept and return a second dummy input/output that in both cases will be <code class="docutils literal notranslate"><span class="pre">None</span></code>. 
In <code class="docutils literal notranslate"><span class="pre">MLP</span></code>, we will use <code class="docutils literal notranslate"><span class="pre">nn.scan</span></code> as in the previous example, but by setting <code class="docutils literal notranslate"><span class="pre">split_rngs={'params':</span> <span class="pre">True}</span></code> and <code class="docutils literal notranslate"><span class="pre">variable_axes={'params':</span> <span class="pre">0}</span></code> we are telling <code class="docutils literal notranslate"><span class="pre">nn.scan</span></code> to create different parameters for each step and to slice the <code class="docutils literal notranslate"><span class="pre">params</span></code> collection along the first axis, effectively implementing a stack of <code class="docutils literal notranslate"><span class="pre">Block</span></code> Modules as in Haiku.</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-34" name="sd-tab-set-17" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Haiku" for="sd-tab-item-34"> Haiku</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">Block</span><span class="p">(</span><span class="n">hk</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">features</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span 
class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">features</span> <span class="o">=</span> <span class="n">features</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">hk</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">features</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">hk</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">hk</span><span class="o">.</span><span class="n">next_rng_key</span><span class="p">(),</span> <span class="mf">0.5</span> <span class="k">if</span> <span class="n">training</span> <span class="k">else</span> <span class="mi">0</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> <span class="k">class</span> <span class="nc">MLP</span><span class="p">(</span><span class="n">hk</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">features</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span 
class="n">num_layers</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">features</span> <span class="o">=</span> <span class="n">features</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">=</span> <span class="n">num_layers</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span> <span class="nd">@hk</span><span class="o">.</span><span class="n">experimental</span><span class="o">.</span><span class="n">layer_stack</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">)</span> <span class="k">def</span> <span class="nf">stack_block</span><span class="p">(</span><span class="n">x</span><span class="p">):</span> <span class="k">return</span> <span class="n">Block</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">features</span><span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="p">)</span> <span class="n">stack</span> <span class="o">=</span> <span class="n">hk</span><span class="o">.</span><span class="n">experimental</span><span class="o">.</span><span class="n">layer_stack</span><span class="p">(</span><span class="bp">self</span><span 
class="o">.</span><span class="n">num_layers</span><span class="p">)</span> <span class="k">return</span> <span class="n">stack_block</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> </pre></div> </div> </div> <input id="sd-tab-item-35" name="sd-tab-set-17" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Flax" for="sd-tab-item-35"> Flax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">Block</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="n">features</span><span class="p">:</span> <span class="nb">int</span> <span class="n">training</span><span class="p">:</span> <span class="nb">bool</span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">_</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">features</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span class="n">deterministic</span><span class="o">=</span><span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">training</span><span class="p">)</span> <span class="n">x</span> <span 
class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="kc">None</span> <span class="k">class</span> <span class="nc">MLP</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="n">features</span><span class="p">:</span> <span class="nb">int</span> <span class="n">num_layers</span><span class="p">:</span> <span class="nb">int</span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span> <span class="n">ScanBlock</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">scan</span><span class="p">(</span> <span class="n">Block</span><span class="p">,</span> <span class="n">variable_axes</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;params&#39;</span><span class="p">:</span> <span class="mi">0</span><span class="p">},</span> <span class="n">split_rngs</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;params&#39;</span><span class="p">:</span> <span class="kc">True</span><span class="p">},</span> <span class="n">length</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span><span class="p">)</span> <span class="n">y</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">ScanBlock</span><span class="p">(</span><span 
class="bp">self</span><span class="o">.</span><span class="n">features</span><span class="p">,</span> <span class="n">training</span><span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span class="kc">None</span><span class="p">)</span> <span class="k">return</span> <span class="n">y</span> </pre></div> </div> </div> </div> </p> <p>Notice how in Flax we pass <code class="docutils literal notranslate"><span class="pre">None</span></code> as the second argument to <code class="docutils literal notranslate"><span class="pre">ScanBlock</span></code> and ignore its second output. These represent the inputs/outputs per-step but they are <code class="docutils literal notranslate"><span class="pre">None</span></code> because in this case we don’t have any.</p> <p>Initializing each model is the same as in previous examples. In this case, we will be specifying that we want to use <code class="docutils literal notranslate"><span class="pre">5</span></code> layers each with <code class="docutils literal notranslate"><span class="pre">64</span></code> features.</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-36" name="sd-tab-set-18" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Haiku" for="sd-tab-item-36"> Haiku</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="p">:</span> <span class="nb">bool</span><span class="p">):</span> <span class="k">return</span> <span class="n">MLP</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">num_layers</span><span class="o">=</span><span class="mi">5</span><span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span 
class="n">training</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">hk</span><span class="o">.</span><span class="n">transform</span><span class="p">(</span><span class="n">forward</span><span class="p">)</span> <span class="n">sample_x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">numpy</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">64</span><span class="p">))</span> <span class="n">params</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span> <span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">sample_x</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">False</span> <span class="c1"># &lt;== inputs</span> <span class="p">)</span> <span class="o">...</span> </pre></div> </div> </div> <input id="sd-tab-item-37" name="sd-tab-set-18" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Flax" for="sd-tab-item-37"> Flax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="o">...</span> <span class="n">model</span> <span class="o">=</span> <span class="n">MLP</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">num_layers</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span> <span class="n">sample_x</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">numpy</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span 
class="mi">64</span><span class="p">))</span> <span class="n">variables</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span> <span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">sample_x</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">False</span> <span class="c1"># &lt;== inputs</span> <span class="p">)</span> <span class="n">params</span> <span class="o">=</span> <span class="n">variables</span><span class="p">[</span><span class="s1">&#39;params&#39;</span><span class="p">]</span> </pre></div> </div> </div> </div> </p> <p>When using scan over layers the one thing you should notice is that all layers are fused into a single layer whose parameters have an extra “layer” dimension on the first axis. In this case, the shape of all parameters will start with <code class="docutils literal notranslate"><span class="pre">(5,</span> <span class="pre">...)</span></code> as we are using <code class="docutils literal notranslate"><span class="pre">5</span></code> layers.</p> <div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-38" name="sd-tab-set-19" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Haiku" for="sd-tab-item-38"> Haiku</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="o">...</span> <span class="p">{</span> <span class="s1">&#39;mlp/__layer_stack_no_per_layer/block/linear&#39;</span><span class="p">:</span> <span class="p">{</span> <span class="s1">&#39;b&#39;</span><span class="p">:</span> <span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">64</span><span class="p">),</span> <span class="s1">&#39;w&#39;</span><span 
class="p">:</span> <span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">64</span><span class="p">)</span> <span class="p">}</span> <span class="p">}</span> <span class="o">...</span> </pre></div> </div> </div> <input id="sd-tab-item-39" name="sd-tab-set-19" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Flax" for="sd-tab-item-39"> Flax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">FrozenDict</span><span class="p">({</span> <span class="n">ScanBlock_0</span><span class="p">:</span> <span class="p">{</span> <span class="n">Dense_0</span><span class="p">:</span> <span class="p">{</span> <span class="n">bias</span><span class="p">:</span> <span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">64</span><span class="p">),</span> <span class="n">kernel</span><span class="p">:</span> <span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">64</span><span class="p">),</span> <span class="p">},</span> <span class="p">},</span> <span class="p">})</span> </pre></div> </div> </div> </div> </div> <div class="section" id="top-level-haiku-functions-vs-top-level-flax-modules"> <h2>Top-level Haiku functions vs top-level Flax modules<a class="headerlink" href="#top-level-haiku-functions-vs-top-level-flax-modules" title="Permalink to this heading">#</a></h2> <p>In Haiku, it is possible to write the entire model as a single function by using the raw <code class="docutils literal notranslate"><span class="pre">hk.{get,set}_{parameter,state}</span></code> to define/access model parameters and states. 
It is very common to write the top-level “Module” as a function instead.</p> <p>The Flax team recommends a more Module-centric approach that uses <cite>__call__</cite> to define the forward function. The corresponding accessors are <cite>nn.Module.param</cite> and <cite>nn.Module.variable</cite> (go to <a class="reference external" href="#handling-state">Handling State</a> for an explanation of collections).</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-40" name="sd-tab-set-20" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Haiku" for="sd-tab-item-40"> Haiku</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">forward</span><span class="p">(</span><span class="n">x</span><span class="p">):</span> <span class="n">counter</span> <span class="o">=</span> <span class="n">hk</span><span class="o">.</span><span class="n">get_state</span><span class="p">(</span><span class="s1">&#39;counter&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">jnp</span><span class="o">.</span><span class="n">int32</span><span class="p">,</span> <span class="n">init</span><span class="o">=</span><span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">)</span> <span class="n">multiplier</span> <span class="o">=</span> <span class="n">hk</span><span class="o">.</span><span class="n">get_parameter</span><span class="p">(</span><span class="s1">&#39;multiplier&#39;</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">,],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span 
class="n">dtype</span><span class="p">,</span> <span class="n">init</span><span class="o">=</span><span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">)</span> <span class="n">output</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">multiplier</span> <span class="o">*</span> <span class="n">counter</span> <span class="n">hk</span><span class="o">.</span><span class="n">set_state</span><span class="p">(</span><span class="s2">&quot;counter&quot;</span><span class="p">,</span> <span class="n">counter</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="k">return</span> <span class="n">output</span> <span class="n">model</span> <span class="o">=</span> <span class="n">hk</span><span class="o">.</span><span class="n">transform_with_state</span><span class="p">(</span><span class="n">forward</span><span class="p">)</span> <span class="n">params</span><span class="p">,</span> <span class="n">state</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jax</span><span class="o">.</span><span class="n">numpy</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">64</span><span class="p">)))</span> </pre></div> </div> </div> <input id="sd-tab-item-41" name="sd-tab-set-20" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="Flax" for="sd-tab-item-41"> Flax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span> <span class="nc">FooModule</span><span class="p">(</span><span 
class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="n">counter</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">variable</span><span class="p">(</span><span class="s1">&#39;counter&#39;</span><span class="p">,</span> <span class="s1">&#39;count&#39;</span><span class="p">,</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">jnp</span><span class="o">.</span><span class="n">ones</span><span class="p">((),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">int32</span><span class="p">))</span> <span class="n">multiplier</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">param</span><span class="p">(</span><span class="s1">&#39;multiplier&#39;</span><span class="p">,</span> <span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">ones_init</span><span class="p">(),</span> <span class="p">[</span><span class="mi">1</span><span class="p">,],</span> <span class="n">x</span><span class="o">.</span><span class="n">dtype</span><span class="p">)</span> <span class="n">output</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">multiplier</span> <span class="o">*</span> <span class="n">counter</span><span class="o">.</span><span class="n">value</span> <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">is_initializing</span><span class="p">():</span> <span class="c1"># otherwise model.init() also increases it</span> <span 
class="n">counter</span><span class="o">.</span><span class="n">value</span> <span class="o">+=</span> <span class="mi">1</span> <span class="k">return</span> <span class="n">output</span> <span class="n">model</span> <span class="o">=</span> <span class="n">FooModule</span><span class="p">()</span> <span class="n">variables</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">),</span> <span class="n">jax</span><span class="o">.</span><span class="n">numpy</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">64</span><span class="p">)))</span> <span class="n">params</span><span class="p">,</span> <span class="n">counter</span> <span class="o">=</span> <span class="n">variables</span><span class="p">[</span><span class="s1">&#39;params&#39;</span><span class="p">],</span> <span class="n">variables</span><span class="p">[</span><span class="s1">&#39;counter&#39;</span><span class="p">]</span> </pre></div> </div> </div> </div> </p> </div> </div> </article> <footer class="prev-next-footer d-print-none"> <div class="prev-next-area"> <a class="left-prev" href="index.html" title="previous page"> <i class="fa-solid fa-angle-left"></i> <div class="prev-next-info"> <p class="prev-next-subtitle">previous</p> <p class="prev-next-title">Converting and upgrading</p> </div> </a> <a class="right-next" href="convert_pytorch_to_flax.html" title="next page"> <div class="prev-next-info"> <p class="prev-next-subtitle">next</p> <p class="prev-next-title">Convert PyTorch models to Flax</p> </div> <i class="fa-solid fa-angle-right"></i> </a> </div> </footer> </div> <div class="bd-sidebar-secondary bd-toc"><div class="sidebar-secondary-items sidebar-secondary__inner"> <div 
class="sidebar-secondary-item"> <div class="page-toc tocsection onthispage"> <i class="fa-solid fa-list"></i> Contents </div> <nav class="bd-toc-nav page-toc"> <ul class="visible nav section-nav flex-column"> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#basic-example">Basic Example</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#handling-state">Handling State</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#using-multiple-methods">Using Multiple Methods</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#lifted-transforms">Lifted Transforms</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#scan-over-layers">Scan over layers</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#top-level-haiku-functions-vs-top-level-flax-modules">Top-level Haiku functions vs top-level Flax modules</a></li> </ul> </nav></div> </div></div> </div> <footer class="bd-footer-content"> <div class="bd-footer-content__inner container"> <div class="footer-item"> <p class="component-author"> By The Flax authors </p> </div> <div class="footer-item"> <p class="copyright"> © Copyright 2023, The Flax authors. <br/> </p> </div> <div class="footer-item"> </div> <div class="footer-item"> </div> </div> </footer> </main> </div> </div> <!-- Scripts loaded after <body> so the DOM is not blocked --> <script src="../../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b"></script> <script src="../../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b"></script> <footer class="bd-footer"> </footer> </body> </html>
