<!DOCTYPE html> <html lang="en" data-content_root="" > <head> <meta charset="utf-8" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" /> <title>Upgrading my codebase to Optax</title> <script data-cfasync="false"> document.documentElement.dataset.mode = localStorage.getItem("mode") || ""; document.documentElement.dataset.theme = localStorage.getItem("theme") || ""; </script> <!-- Loaded before other Sphinx assets --> <link href="../../_static/styles/theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link href="../../_static/styles/bootstrap.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link href="../../_static/styles/pydata-sphinx-theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link href="../../_static/vendor/fontawesome/6.5.2/css/all.min.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="../../_static/vendor/fontawesome/6.5.2/webfonts/fa-solid-900.woff2" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="../../_static/vendor/fontawesome/6.5.2/webfonts/fa-brands-400.woff2" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="../../_static/vendor/fontawesome/6.5.2/webfonts/fa-regular-400.woff2" /> <link rel="stylesheet" type="text/css" href="../../_static/pygments.css" /> <link rel="stylesheet" type="text/css" href="../../_static/styles/sphinx-book-theme.css" /> <link rel="stylesheet" type="text/css" href="../../_static/mystnb.4510f1fc1dee50b3e5859aac5469c37c29e427902b24a333a5f9fcb2f0b3ac41.css" /> <link rel="stylesheet" type="text/css" href="../../_static/sphinx-design.5ea377869091fd0449014c60fc090103.min.css" /> <link rel="stylesheet" type="text/css" href="../../_static/css/flax_theme.css" /> <!-- Pre-loaded scripts that we'll load fully later --> <link rel="preload" as="script" href="../../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b" /> <link rel="preload" as="script" 
href="../../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" /> <script src="../../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script> <script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script> <script src="../../_static/jquery.js"></script> <script src="../../_static/underscore.js"></script> <script src="../../_static/_sphinx_javascript_frameworks_compat.js"></script> <script src="../../_static/doctools.js"></script> <script src="../../_static/sphinx_highlight.js"></script> <script src="../../_static/scripts/sphinx-book-theme.js"></script> <script src="../../_static/design-tabs.js"></script> <script>DOCUMENTATION_OPTIONS.pagename = 'guides/converting_and_upgrading/optax_update_guide';</script> <link rel="shortcut icon" href="../../_static/flax.png"/> <link rel="index" title="Index" href="../../genindex.html" /> <link rel="search" title="Search" href="../../search.html" /> <link rel="next" title="Upgrading my codebase to Linen" href="linen_upgrade_guide.html" /> <link rel="prev" title="Migrate checkpointing to Orbax" href="orbax_upgrade_guide.html" /> <meta name="viewport" content="width=device-width, initial-scale=1"/> <meta name="docsearch:language" content="en"/> <script async type="text/javascript" src="/_/static/javascript/readthedocs-addons.js"></script><meta name="readthedocs-project-slug" content="flax-linen" /><meta name="readthedocs-version-slug" content="latest" /><meta name="readthedocs-resolver-filename" content="/guides/converting_and_upgrading/optax_update_guide.html" /><meta name="readthedocs-http-status" content="200" /></head> <body data-bs-spy="scroll" data-bs-target=".bd-toc-nav" data-offset="180" data-bs-root-margin="0px 0px -60%" data-default-mode=""> <div id="pst-skip-link" class="skip-link d-print-none"><a href="#main-content">Skip to main content</a></div> <div id="pst-scroll-pixel-helper"></div> <button type="button" class="btn 
rounded-pill" id="pst-back-to-top"> <i class="fa-solid fa-arrow-up"></i>Back to top</button> <input type="checkbox" class="sidebar-toggle" id="pst-primary-sidebar-checkbox"/> <label class="overlay overlay-primary" for="pst-primary-sidebar-checkbox"></label> <input type="checkbox" class="sidebar-toggle" id="pst-secondary-sidebar-checkbox"/> <label class="overlay overlay-secondary" for="pst-secondary-sidebar-checkbox"></label> <div class="search-button__wrapper"> <div class="search-button__overlay"></div> <div class="search-button__search-container"> <form class="bd-search d-flex align-items-center" action="../../search.html" method="get"> <i class="fa-solid fa-magnifying-glass"></i> <input type="search" class="form-control" name="q" id="search-input" placeholder="Search..." aria-label="Search..." autocomplete="off" autocorrect="off" autocapitalize="off" spellcheck="false"/> <span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span> </form></div> </div> <div class="pst-async-banner-revealer d-none"> <aside id="bd-header-version-warning" class="d-none d-print-none" aria-label="Version warning"></aside> </div> <aside class="bd-header-announcement" aria-label="Announcement"> <div class="bd-header-announcement__content"> <a href="https://flax.readthedocs.io/en/latest/index.html" style="text-decoration: none; color: white;" > This site covers the old Flax Linen API. 
<span style="color: lightgray;">[Explore the new <b>Flax NNX</b> API ✨]</span> </a> </div> </aside> <header class="bd-header navbar navbar-expand-lg bd-navbar d-print-none"> </header> <div class="bd-container"> <div class="bd-container__inner bd-page-width"> <div class="bd-sidebar-primary bd-sidebar"> <div class="sidebar-header-items sidebar-primary__section"> </div> <div class="sidebar-primary-items__start sidebar-primary__section"> <div class="sidebar-primary-item"> <a class="navbar-brand logo" href="../../index.html"> <img src="../../_static/flax.png" class="logo__image only-light" alt=" - Home"/> <script>document.write(`<img src="../../_static/flax.png" class="logo__image only-dark" alt=" - Home"/>`);</script> </a></div> <div class="sidebar-primary-item"> <script> document.write(` <button class="btn search-button-field search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="fa-solid fa-magnifying-glass"></i> <span class="search-button__default-text">Search</span> <span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span> </button> `); </script></div> </div> <div class="sidebar-primary-items__end sidebar-primary__section"> </div> <div id="rtd-footer-container"></div> </div> <main id="main-content" class="bd-main" role="main"> <div class="sbt-scroll-pixel-helper"></div> <div class="bd-content"> <div class="bd-article-container"> <div class="bd-header-article d-print-none"> <div class="header-article-items header-article__inner"> <div class="header-article-items__start"> <div class="header-article-item"><button class="sidebar-toggle primary-toggle btn btn-sm" title="Toggle primary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip"> <span class="fa-solid fa-bars"></span> </button></div> </div> <div class="header-article-items__end"> <div class="header-article-item"> <div class="article-header-buttons"> <a href="https://github.com/google/flax" target="_blank" class="btn btn-sm btn-source-repository-button" title="Source repository" data-bs-placement="bottom" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fab fa-github"></i> </span> </a> <div class="dropdown dropdown-download-buttons"> <button class="btn dropdown-toggle" type="button" data-bs-toggle="dropdown" aria-expanded="false" 
aria-label="Download this page"> <i class="fas fa-download"></i> </button> <ul class="dropdown-menu"> <li><a href="../../_sources/guides/converting_and_upgrading/optax_update_guide.rst" target="_blank" class="btn btn-sm btn-download-source-button dropdown-item" title="Download source file" data-bs-placement="left" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-file"></i> </span> <span class="btn__text-container">.rst</span> </a> </li> <li> <button onclick="window.print()" class="btn btn-sm btn-download-pdf-button dropdown-item" title="Print to PDF" data-bs-placement="left" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-file-pdf"></i> </span> <span class="btn__text-container">.pdf</span> </button> </li> </ul> </div> <button onclick="toggleFullScreen()" class="btn btn-sm btn-fullscreen-button" title="Fullscreen mode" data-bs-placement="bottom" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-expand"></i> </span> </button> <script> document.write(` <button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button" title="light/dark" aria-label="light/dark" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light"></i> <i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark"></i> <i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto"></i> </button> `); </script> <script> document.write(` <button class="btn btn-sm pst-navbar-icon search-button search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="fa-solid fa-magnifying-glass fa-lg"></i> </button> `); </script> <button class="sidebar-toggle secondary-toggle btn btn-sm" title="Toggle secondary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip"> <span class="fa-solid fa-list"></span> </button> </div></div> </div> </div> </div> <div id="jb-print-docs-body" 
class="onlyprint"> <h1>Upgrading my codebase to Optax</h1> <!-- Table of contents --> <div id="print-main-content"> <div id="jb-print-toc"> <div> <h2> Contents </h2> </div> <nav aria-label="Page"> <ul class="visible nav section-nav flex-column"> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#replacing-flax-optim-with-optax">Replacing <code class="docutils literal notranslate"><span class="pre">flax.optim</span></code> with <code class="docutils literal notranslate"><span class="pre">optax</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#composable-gradient-transformations">Composable Gradient Transformations</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#weight-decay">Weight Decay</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#gradient-clipping">Gradient Clipping</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#learning-rate-schedules">Learning Rate Schedules</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#multiple-optimizers-updating-a-subset-of-parameters">Multiple Optimizers / Updating a Subset of Parameters</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#final-words">Final Words</a></li> </ul> </nav> </div> </div> </div> <div id="searchbox"></div> <article class="bd-article"> <div class="section" id="upgrading-my-codebase-to-optax"> <h1>Upgrading my codebase to Optax<a class="headerlink" href="#upgrading-my-codebase-to-optax" title="Permalink to this heading">#</a></h1> <p>We have proposed to replace <code class="xref py py-mod docutils literal notranslate"><span class="pre">flax.optim</span></code> with <a class="reference external" href="https://optax.readthedocs.io">Optax</a> in 2021 with <a class="reference external" 
href="https://github.com/google/flax/blob/main/docs/flip/1009-optimizer-api.md">FLIP #1009</a>, and the Flax optimizers were removed in v0.6.0. This guide is aimed at <code class="xref py py-mod docutils literal notranslate"><span class="pre">flax.optim</span></code> users and helps them update their code to Optax.</p> <p>See also Optax’s quick start documentation: <a class="reference external" href="https://optax.readthedocs.io/en/latest/getting_started.html">https://optax.readthedocs.io/en/latest/getting_started.html</a></p> <div class="section" id="replacing-flax-optim-with-optax"> <h2>Replacing <code class="docutils literal notranslate"><span class="pre">flax.optim</span></code> with <code class="docutils literal notranslate"><span class="pre">optax</span></code><a class="headerlink" href="#replacing-flax-optim-with-optax" title="Permalink to this heading">#</a></h2> <p>Optax has drop-in replacements for all of Flax’s optimizers. Refer to the <a class="reference external" href="https://optax.readthedocs.io/en/latest/api/optimizers.html">Common Optimizers</a> page of Optax’s documentation for API details.</p> <p>Usage is very similar; the main difference is that <code class="docutils literal notranslate"><span class="pre">optax</span></code> does not keep a copy of the <code class="docutils literal notranslate"><span class="pre">params</span></code>, so they must be passed around separately. 
Flax provides the utility <a class="reference internal" href="../../api_reference/flax.training.html#flax.training.train_state.TrainState" title="flax.training.train_state.TrainState"><code class="xref py py-class docutils literal notranslate"><span class="pre">TrainState</span></code></a> to store optimizer state, parameters, and other associated data in a single dataclass (not used in code below).</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-0" name="sd-tab-set-0" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="flax.optim" for="sd-tab-item-0"> flax.optim</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@jax</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
  <span class="n">grads</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">grad</span><span class="p">(</span><span class="n">loss</span><span class="p">)(</span><span class="n">optimizer</span><span class="o">.</span><span class="n">target</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
  <span class="k">return</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">apply_gradient</span><span class="p">(</span><span class="n">grads</span><span class="p">)</span>

<span class="n">optimizer_def</span> <span class="o">=</span> <span class="n">flax</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Momentum</span><span class="p">(</span>
    <span class="n">learning_rate</span><span class="p">,</span> <span class="n">momentum</span><span class="p">)</span>
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer_def</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">variables</span><span class="p">[</span><span class="s1">'params'</span><span class="p">])</span>

<span class="k">for</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">get_ds_train</span><span class="p">():</span>
  <span class="n">optimizer</span> <span class="o">=</span> <span class="n">train_step</span><span class="p">(</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
</pre></div> </div> </div> <input id="sd-tab-item-1" name="sd-tab-set-0" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="optax" for="sd-tab-item-1"> optax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="nd">@jax</span><span class="o">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">opt_state</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span>
  <span class="n">grads</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">grad</span><span class="p">(</span><span class="n">loss</span><span class="p">)(</span><span class="n">params</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
  <span class="n">updates</span><span class="p">,</span> <span class="n">opt_state</span> <span class="o">=</span> <span class="n">tx</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">grads</span><span class="p">,</span> <span class="n">opt_state</span><span class="p">)</span>
  <span class="n">params</span> <span class="o">=</span> <span class="n">optax</span><span class="o">.</span><span class="n">apply_updates</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">updates</span><span class="p">)</span>
  <span class="k">return</span> <span class="n">params</span><span class="p">,</span> <span class="n">opt_state</span>

<span class="n">tx</span> <span class="o">=</span> <span class="n">optax</span><span class="o">.</span><span class="n">sgd</span><span class="p">(</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">momentum</span><span class="p">)</span>
<span class="n">params</span> <span class="o">=</span> <span class="n">variables</span><span class="p">[</span><span class="s1">'params'</span><span class="p">]</span>
<span class="n">opt_state</span> <span class="o">=</span> <span class="n">tx</span><span class="o">.</span><span class="n">init</span><span class="p">(</span><span class="n">params</span><span class="p">)</span>

<span class="k">for</span> <span class="n">batch</span> <span class="ow">in</span> <span class="n">ds_train</span><span class="p">:</span>
  <span class="n">params</span><span class="p">,</span> <span class="n">opt_state</span> <span class="o">=</span> <span class="n">train_step</span><span class="p">(</span><span class="n">params</span><span class="p">,</span> <span class="n">opt_state</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span>
</pre></div> </div> </div> </div> </p> </div> <div class="section" id="composable-gradient-transformations"> <h2>Composable Gradient Transformations<a class="headerlink" href="#composable-gradient-transformations" title="Permalink to this heading">#</a></h2> <p>The function <a class="reference external" href="https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.sgd"><code class="docutils literal notranslate"><span class="pre">optax.sgd()</span></code></a> used in the code snippet above is simply a wrapper for the sequential application of two gradient transformations. 
Instead of using this alias, it is common to use <a class="reference external" href="https://optax.readthedocs.io/en/latest/api/combining_optimizers.html#optax.chain"><code class="docutils literal notranslate"><span class="pre">optax.chain()</span></code></a> to combine multiple of these generic building blocks.</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-2" name="sd-tab-set-1" type="radio"> <label class="sd-tab-label" for="sd-tab-item-2"> Pre-defined alias</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># Note that the aliases follow the convention to use positive</span> <span class="c1"># values for the learning rate by default.</span> <span class="n">tx</span> <span class="o">=</span> <span class="n">optax</span><span class="o">.</span><span class="n">sgd</span><span class="p">(</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">momentum</span><span class="p">)</span> </pre></div> </div> </div> <input id="sd-tab-item-3" name="sd-tab-set-1" type="radio"> <label class="sd-tab-label" for="sd-tab-item-3"> Combining transformations</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1">#</span> <span class="n">tx</span> <span class="o">=</span> <span class="n">optax</span><span class="o">.</span><span class="n">chain</span><span class="p">(</span> <span class="c1"># 1. Step: keep a trace of past updates and add to gradients.</span> <span class="n">optax</span><span class="o">.</span><span class="n">trace</span><span class="p">(</span><span class="n">decay</span><span class="o">=</span><span class="n">momentum</span><span class="p">),</span> <span class="c1"># 2. 
Step: multiply the result from step 1 by the negative learning rate.</span> <span class="c1"># Note that `optax.apply_updates()` simply adds the final updates to the</span> <span class="c1"># parameters, so we must make sure to flip the sign here for gradient</span> <span class="c1"># descent.</span> <span class="n">optax</span><span class="o">.</span><span class="n">scale</span><span class="p">(</span><span class="o">-</span><span class="n">learning_rate</span><span class="p">),</span> <span class="p">)</span> </pre></div> </div> </div> </div> </p> </div> <div class="section" id="weight-decay"> <h2>Weight Decay<a class="headerlink" href="#weight-decay" title="Permalink to this heading">#</a></h2> <p>Some of Flax’s optimizers also include weight decay. In Optax, some optimizers have a weight decay parameter (such as <a class="reference external" href="https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.adamw"><code class="docutils literal notranslate"><span class="pre">optax.adamw()</span></code></a>); for the others, weight decay can be added as another “gradient transformation”, <a class="reference external" href="https://optax.readthedocs.io/en/latest/api/transformations.html#optax.add_decayed_weights"><code class="docutils literal notranslate"><span class="pre">optax.add_decayed_weights()</span></code></a>, which adds an update derived from the parameters.</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-4" name="sd-tab-set-2" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="flax.optim" for="sd-tab-item-4"> flax.optim</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">optimizer_def</span> <span class="o">=</span> <span class="n">flax</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span> <span 
class="n">learning_rate</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">weight_decay</span><span class="p">)</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer_def</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">variables</span><span class="p">[</span><span class="s1">'params'</span><span class="p">])</span> </pre></div> </div> </div> <input id="sd-tab-item-5" name="sd-tab-set-2" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="optax" for="sd-tab-item-5"> optax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># (Note that you could also use `optax.adamw()` in this case)</span> <span class="n">tx</span> <span class="o">=</span> <span class="n">optax</span><span class="o">.</span><span class="n">chain</span><span class="p">(</span> <span class="n">optax</span><span class="o">.</span><span class="n">scale_by_adam</span><span class="p">(),</span> <span class="n">optax</span><span class="o">.</span><span class="n">add_decayed_weights</span><span class="p">(</span><span class="n">weight_decay</span><span class="p">),</span> <span class="c1"># params -= learning_rate * (adam(grads) + params * weight_decay)</span> <span class="n">optax</span><span class="o">.</span><span class="n">scale</span><span class="p">(</span><span class="o">-</span><span class="n">learning_rate</span><span class="p">),</span> <span class="p">)</span> <span class="c1"># Note that you'll need to specify `params` when computing the updates:</span> <span class="c1"># tx.update(grads, opt_state, params)</span> </pre></div> </div> </div> </div> </p> </div> <div class="section" id="gradient-clipping"> <h2>Gradient Clipping<a class="headerlink" href="#gradient-clipping" title="Permalink to this heading">#</a></h2> <p>Training can be 
stabilized by clipping gradients to a global norm (<a class="reference external" href="https://arxiv.org/abs/1211.5063">Pascanu et al., 2012</a>). In Flax, this is often done by processing the gradients before passing them to the optimizer. With Optax, this becomes just another gradient transformation: <a class="reference external" href="https://optax.readthedocs.io/en/latest/api/transformations.html#optax.clip_by_global_norm"><code class="docutils literal notranslate"><span class="pre">optax.clip_by_global_norm()</span></code></a>.</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-6" name="sd-tab-set-3" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="flax.optim" for="sd-tab-item-6"> flax.optim</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">grad</span><span class="p">(</span><span class="n">loss</span><span class="p">)(</span><span class="n">optimizer</span><span class="o">.</span><span class="n">target</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span> <span class="n">grads_flat</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_flatten</span><span class="p">(</span><span class="n">grads</span><span class="p">)</span> <span class="n">global_l2</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="nb">sum</span><span class="p">([</span><span 
class="n">jnp</span><span class="o">.</span><span class="n">vdot</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">p</span><span class="p">)</span> <span class="k">for</span> <span class="n">p</span> <span class="ow">in</span> <span class="n">grads_flat</span><span class="p">]))</span> <span class="n">g_factor</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">minimum</span><span class="p">(</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">grad_clip_norm</span> <span class="o">/</span> <span class="n">global_l2</span><span class="p">)</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">g</span><span class="p">:</span> <span class="n">g</span> <span class="o">*</span> <span class="n">g_factor</span><span class="p">,</span> <span class="n">grads</span><span class="p">)</span> <span class="k">return</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">apply_gradient</span><span class="p">(</span><span class="n">grads</span><span class="p">)</span> </pre></div> </div> </div> <input id="sd-tab-item-7" name="sd-tab-set-3" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="optax" for="sd-tab-item-7"> optax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">tx</span> <span class="o">=</span> <span class="n">optax</span><span class="o">.</span><span class="n">chain</span><span class="p">(</span> <span class="n">optax</span><span class="o">.</span><span class="n">clip_by_global_norm</span><span class="p">(</span><span class="n">grad_clip_norm</span><span class="p">),</span> <span 
class="n">optax</span><span class="o">.</span><span class="n">trace</span><span class="p">(</span><span class="n">decay</span><span class="o">=</span><span class="n">momentum</span><span class="p">),</span> <span class="n">optax</span><span class="o">.</span><span class="n">scale</span><span class="p">(</span><span class="o">-</span><span class="n">learning_rate</span><span class="p">),</span> <span class="p">)</span> </pre></div> </div> </div> </div> </p> </div> <div class="section" id="learning-rate-schedules"> <h2>Learning Rate Schedules<a class="headerlink" href="#learning-rate-schedules" title="Permalink to this heading">#</a></h2> <p>For learning rate schedules, Flax allows overwriting hyperparameters when applying the gradients. Optax maintains a step counter and provides this as an argument to a function for scaling the updates added with <a class="reference external" href="https://optax.readthedocs.io/en/latest/api/transformations.html#optax.scale_by_schedule"><code class="docutils literal notranslate"><span class="pre">optax.scale_by_schedule()</span></code></a>. Optax also allows specifying functions to inject arbitrary scalar values for other gradient updates via <a class="reference external" href="https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html#optax.inject_hyperparams"><code class="docutils literal notranslate"><span class="pre">optax.inject_hyperparams()</span></code></a>.</p> <p>Read more about learning rate schedules in the <span class="xref std std-doc">lr_schedule</span> guide.</p> <p>Read more about schedules defined in Optax under <a class="reference external" href="https://optax.readthedocs.io/en/latest/api/optimizer_schedules.html">Optimizer Schedules</a>. The standard optimizers (like <code class="docutils literal notranslate"><span class="pre">optax.adam()</span></code>, <code class="docutils literal notranslate"><span class="pre">optax.sgd()</span></code>, etc.) 
also accept a learning rate schedule as a parameter for <code class="docutils literal notranslate"><span class="pre">learning_rate</span></code>.</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-8" name="sd-tab-set-4" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="flax.optim" for="sd-tab-item-8"> flax.optim</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">batch</span><span class="p">):</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">grad</span><span class="p">(</span><span class="n">loss</span><span class="p">)(</span><span class="n">optimizer</span><span class="o">.</span><span class="n">target</span><span class="p">,</span> <span class="n">batch</span><span class="p">)</span> <span class="k">return</span> <span class="n">step</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">apply_gradient</span><span class="p">(</span><span class="n">grads</span><span class="p">,</span> <span class="n">learning_rate</span><span class="o">=</span><span class="n">schedule</span><span class="p">(</span><span class="n">step</span><span class="p">))</span> </pre></div> </div> </div> <input id="sd-tab-item-9" name="sd-tab-set-4" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="optax" for="sd-tab-item-9"> optax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">tx</span> <span class="o">=</span> <span class="n">optax</span><span 
class="o">.</span><span class="n">chain</span><span class="p">(</span> <span class="n">optax</span><span class="o">.</span><span class="n">trace</span><span class="p">(</span><span class="n">decay</span><span class="o">=</span><span class="n">momentum</span><span class="p">),</span> <span class="c1"># Note that we still want a negative value for scaling the updates!</span> <span class="n">optax</span><span class="o">.</span><span class="n">scale_by_schedule</span><span class="p">(</span><span class="k">lambda</span> <span class="n">step</span><span class="p">:</span> <span class="o">-</span><span class="n">schedule</span><span class="p">(</span><span class="n">step</span><span class="p">)),</span> <span class="p">)</span> </pre></div> </div> </div> </div> </p> </div> <div class="section" id="multiple-optimizers-updating-a-subset-of-parameters"> <h2>Multiple Optimizers / Updating a Subset of Parameters<a class="headerlink" href="#multiple-optimizers-updating-a-subset-of-parameters" title="Permalink to this heading">#</a></h2> <p>In Flax, traversals are used to specify which parameters should be updated by an optimizer. You can combine traversals using <code class="xref py py-class docutils literal notranslate"><span class="pre">flax.optim.MultiOptimizer</span></code> to apply different optimizers to different parameters. 
The equivalent in Optax is <a class="reference external" href="https://optax.readthedocs.io/en/latest/api/optimizer_wrappers.html#optax.masked"><code class="docutils literal notranslate"><span class="pre">optax.masked()</span></code></a> and <a class="reference external" href="https://optax.readthedocs.io/en/latest/api/combining_optimizers.html#optax.chain"><code class="docutils literal notranslate"><span class="pre">optax.chain()</span></code></a>.</p> <p>Note that the example below uses <a class="reference internal" href="../../api_reference/flax.traverse_util.html#module-flax.traverse_util" title="flax.traverse_util"><code class="xref py py-mod docutils literal notranslate"><span class="pre">flax.traverse_util</span></code></a> to create the boolean masks required by <a class="reference external" href="https://optax.readthedocs.io/en/latest/api/optimizer_wrappers.html#optax.masked"><code class="docutils literal notranslate"><span class="pre">optax.masked()</span></code></a>; alternatively, you could create them manually, or use <a class="reference external" href="https://optax.readthedocs.io/en/latest/api/combining_optimizers.html#optax.multi_transform"><code class="docutils literal notranslate"><span class="pre">optax.multi_transform()</span></code></a>, which takes a multivalent pytree to specify gradient transformations.</p> <p>Beware that <a class="reference external" href="https://optax.readthedocs.io/en/latest/api/optimizer_wrappers.html#optax.masked"><code class="docutils literal notranslate"><span class="pre">optax.masked()</span></code></a> flattens the pytree internally and the inner gradient transformations will only be called with that partial flattened view of the params/gradients. 
This is usually not a problem, but it makes it hard to nest multiple levels of masked gradient transformations (because the inner masks will expect the mask to be defined in terms of the partial flattened view, which is not readily available outside the outer mask).</p> <p><div class="sd-tab-set docutils"> <input checked="checked" id="sd-tab-item-10" name="sd-tab-set-5" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="flax.optim" for="sd-tab-item-10"> flax.optim</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">kernels</span> <span class="o">=</span> <span class="n">flax</span><span class="o">.</span><span class="n">traverse_util</span><span class="o">.</span><span class="n">ModelParamTraversal</span><span class="p">(</span><span class="k">lambda</span> <span class="n">p</span><span class="p">,</span> <span class="n">_</span><span class="p">:</span> <span class="s1">'kernel'</span> <span class="ow">in</span> <span class="n">p</span><span class="p">)</span> <span class="n">biases</span> <span class="o">=</span> <span class="n">flax</span><span class="o">.</span><span class="n">traverse_util</span><span class="o">.</span><span class="n">ModelParamTraversal</span><span class="p">(</span><span class="k">lambda</span> <span class="n">p</span><span class="p">,</span> <span class="n">_</span><span class="p">:</span> <span class="s1">'bias'</span> <span class="ow">in</span> <span class="n">p</span><span class="p">)</span> <span class="n">kernel_opt</span> <span class="o">=</span> <span class="n">flax</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Momentum</span><span class="p">(</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">momentum</span><span class="p">)</span> <span class="n">bias_opt</span> <span class="o">=</span> <span class="n">flax</span><span 
class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">Momentum</span><span class="p">(</span><span class="n">learning_rate</span> <span class="o">*</span> <span class="mf">0.1</span><span class="p">,</span> <span class="n">momentum</span><span class="p">)</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">flax</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">MultiOptimizer</span><span class="p">(</span> <span class="p">(</span><span class="n">kernels</span><span class="p">,</span> <span class="n">kernel_opt</span><span class="p">),</span> <span class="p">(</span><span class="n">biases</span><span class="p">,</span> <span class="n">bias_opt</span><span class="p">)</span> <span class="p">)</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">variables</span><span class="p">[</span><span class="s1">'params'</span><span class="p">])</span> </pre></div> </div> </div> <input id="sd-tab-item-11" name="sd-tab-set-5" type="radio"> <label class="sd-tab-label" data-sync-group="tab" data-sync-id="optax" for="sd-tab-item-11"> optax</label><div class="sd-tab-content docutils"> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">kernels</span> <span class="o">=</span> <span class="n">flax</span><span class="o">.</span><span class="n">traverse_util</span><span class="o">.</span><span class="n">ModelParamTraversal</span><span class="p">(</span><span class="k">lambda</span> <span class="n">p</span><span class="p">,</span> <span class="n">_</span><span class="p">:</span> <span class="s1">'kernel'</span> <span class="ow">in</span> <span class="n">p</span><span class="p">)</span> <span class="n">biases</span> <span class="o">=</span> <span class="n">flax</span><span class="o">.</span><span class="n">traverse_util</span><span class="o">.</span><span 
class="n">ModelParamTraversal</span><span class="p">(</span><span class="k">lambda</span> <span class="n">p</span><span class="p">,</span> <span class="n">_</span><span class="p">:</span> <span class="s1">'bias'</span> <span class="ow">in</span> <span class="n">p</span><span class="p">)</span> <span class="n">all_false</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">tree_util</span><span class="o">.</span><span class="n">tree_map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">_</span><span class="p">:</span> <span class="kc">False</span><span class="p">,</span> <span class="n">params</span><span class="p">)</span> <span class="n">kernels_mask</span> <span class="o">=</span> <span class="n">kernels</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="k">lambda</span> <span class="n">_</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span> <span class="n">all_false</span><span class="p">)</span> <span class="n">biases_mask</span> <span class="o">=</span> <span class="n">biases</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="k">lambda</span> <span class="n">_</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span> <span class="n">all_false</span><span class="p">)</span> <span class="n">tx</span> <span class="o">=</span> <span class="n">optax</span><span class="o">.</span><span class="n">chain</span><span class="p">(</span> <span class="n">optax</span><span class="o">.</span><span class="n">trace</span><span class="p">(</span><span class="n">decay</span><span class="o">=</span><span class="n">momentum</span><span class="p">),</span> <span class="n">optax</span><span class="o">.</span><span class="n">masked</span><span class="p">(</span><span class="n">optax</span><span class="o">.</span><span class="n">scale</span><span class="p">(</span><span 
class="o">-</span><span class="n">learning_rate</span><span class="p">),</span> <span class="n">kernels_mask</span><span class="p">),</span> <span class="n">optax</span><span class="o">.</span><span class="n">masked</span><span class="p">(</span><span class="n">optax</span><span class="o">.</span><span class="n">scale</span><span class="p">(</span><span class="o">-</span><span class="n">learning_rate</span> <span class="o">*</span> <span class="mf">0.1</span><span class="p">),</span> <span class="n">biases_mask</span><span class="p">),</span> <span class="p">)</span> </pre></div> </div> </div> </div> </p> </div> <div class="section" id="final-words"> <h2>Final Words<a class="headerlink" href="#final-words" title="Permalink to this heading">#</a></h2> <p>All of the above patterns can of course be mixed, and Optax makes it possible to encapsulate all these transformations in a single place outside the main training loop, which makes testing much easier.</p> </div> </div> </article> </div> </div>
<footer class="bd-footer-content"> <div class="bd-footer-content__inner container"> <div class="footer-item"> <p class="component-author"> By The Flax authors </p> </div> <div class="footer-item"> <p class="copyright"> © Copyright 2023, The Flax authors. <br/> </p> </div> <div class="footer-item"> </div> <div class="footer-item"> </div> </div> </footer> </main> </div> </div> <!-- Scripts loaded after <body> so the DOM is not blocked --> <script src="../../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b"></script> <script src="../../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b"></script> <footer class="bd-footer"> </footer> </body> </html>