CINXE.COM

Initializers

<!DOCTYPE html> <html lang="en" data-content_root="" > <head> <meta charset="utf-8" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" /> <title>Initializers</title> <script data-cfasync="false"> document.documentElement.dataset.mode = localStorage.getItem("mode") || ""; document.documentElement.dataset.theme = localStorage.getItem("theme") || ""; </script> <!-- Loaded before other Sphinx assets --> <link href="../../_static/styles/theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link href="../../_static/styles/bootstrap.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link href="../../_static/styles/pydata-sphinx-theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link href="../../_static/vendor/fontawesome/6.5.2/css/all.min.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="../../_static/vendor/fontawesome/6.5.2/webfonts/fa-solid-900.woff2" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="../../_static/vendor/fontawesome/6.5.2/webfonts/fa-brands-400.woff2" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="../../_static/vendor/fontawesome/6.5.2/webfonts/fa-regular-400.woff2" /> <link rel="stylesheet" type="text/css" href="../../_static/pygments.css" /> <link rel="stylesheet" type="text/css" href="../../_static/styles/sphinx-book-theme.css" /> <link rel="stylesheet" type="text/css" href="../../_static/mystnb.4510f1fc1dee50b3e5859aac5469c37c29e427902b24a333a5f9fcb2f0b3ac41.css" /> <link rel="stylesheet" type="text/css" href="../../_static/sphinx-design.5ea377869091fd0449014c60fc090103.min.css" /> <link rel="stylesheet" type="text/css" href="../../_static/css/flax_theme.css" /> <!-- Pre-loaded scripts that we'll load fully later --> <link rel="preload" as="script" href="../../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b" /> <link rel="preload" as="script" 
href="../../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" /> <script src="../../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script> <script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script> <script src="../../_static/jquery.js"></script> <script src="../../_static/underscore.js"></script> <script src="../../_static/_sphinx_javascript_frameworks_compat.js"></script> <script src="../../_static/doctools.js"></script> <script src="../../_static/sphinx_highlight.js"></script> <script src="../../_static/scripts/sphinx-book-theme.js"></script> <script src="../../_static/design-tabs.js"></script> <script>window.MathJax = {"options": {"processHtmlClass": "tex2jax_process|mathjax_process|math|output_area"}}</script> <script defer="defer" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script> <script>DOCUMENTATION_OPTIONS.pagename = 'api_reference/flax.linen/initializers';</script> <link rel="shortcut icon" href="../../_static/flax.png"/> <link rel="index" title="Index" href="../../genindex.html" /> <link rel="search" title="Search" href="../../search.html" /> <link rel="next" title="Transformations" href="transformations.html" /> <link rel="prev" title="Activation functions" href="activation_functions.html" /> <meta name="viewport" content="width=device-width, initial-scale=1"/> <meta name="docsearch:language" content="en"/> <script async type="text/javascript" src="/_/static/javascript/readthedocs-addons.js"></script><meta name="readthedocs-project-slug" content="flax-linen" /><meta name="readthedocs-version-slug" content="latest" /><meta name="readthedocs-resolver-filename" content="/api_reference/flax.linen/initializers.html" /><meta name="readthedocs-http-status" content="200" /></head> <body data-bs-spy="scroll" data-bs-target=".bd-toc-nav" data-offset="180" data-bs-root-margin="0px 0px -60%" data-default-mode=""> <div id="pst-skip-link" 
class="skip-link d-print-none"><a href="#main-content">Skip to main content</a></div> <div id="pst-scroll-pixel-helper"></div> <button type="button" class="btn rounded-pill" id="pst-back-to-top"> <i class="fa-solid fa-arrow-up"></i>Back to top</button> <input type="checkbox" class="sidebar-toggle" id="pst-primary-sidebar-checkbox"/> <label class="overlay overlay-primary" for="pst-primary-sidebar-checkbox"></label> <input type="checkbox" class="sidebar-toggle" id="pst-secondary-sidebar-checkbox"/> <label class="overlay overlay-secondary" for="pst-secondary-sidebar-checkbox"></label> <div class="search-button__wrapper"> <div class="search-button__overlay"></div> <div class="search-button__search-container"> <form class="bd-search d-flex align-items-center" action="../../search.html" method="get"> <i class="fa-solid fa-magnifying-glass"></i> <input type="search" class="form-control" name="q" id="search-input" placeholder="Search..." aria-label="Search..." autocomplete="off" autocorrect="off" autocapitalize="off" spellcheck="false"/> <span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span> </form></div> </div> <div class="pst-async-banner-revealer d-none"> <aside id="bd-header-version-warning" class="d-none d-print-none" aria-label="Version warning"></aside> </div> <aside class="bd-header-announcement" aria-label="Announcement"> <div class="bd-header-announcement__content"> <a href="https://flax.readthedocs.io/en/latest/index.html" style="text-decoration: none; color: white;" > This site covers the old Flax Linen API. 
<span style="color: lightgray;">[Explore the new <b>Flax NNX</b> API ✨]</span> </a> </div> </aside> <header class="bd-header navbar navbar-expand-lg bd-navbar d-print-none"> </header> <div class="bd-container"> <div class="bd-container__inner bd-page-width"> <div class="bd-sidebar-primary bd-sidebar"> <div class="sidebar-header-items sidebar-primary__section"> </div> <div class="sidebar-primary-items__start sidebar-primary__section"> <div class="sidebar-primary-item"> <a class="navbar-brand logo" href="../../index.html"> <img src="../../_static/flax.png" class="logo__image only-light" alt="Flax - Home"/> <script>document.write(`<img src="../../_static/flax.png" class="logo__image only-dark" alt="Flax - Home"/>`);</script> </a></div> <div class="sidebar-primary-item"> <script> document.write(` <button class="btn search-button-field search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="fa-solid fa-magnifying-glass"></i> <span class="search-button__default-text">Search</span> <span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span> </button> `); </script></div> <div class="sidebar-primary-item"><nav class="bd-links bd-docs-nav" aria-label="Main"> <div class="bd-toc-item navbar-nav active"> <ul class="current nav bd-sidenav"> <li class="toctree-l1"><a class="reference internal" href="../../quick_start.html">Quick start</a></li> <li class="toctree-l1"><a class="reference internal" href="../../guides/flax_fundamentals/flax_basics.html">Flax Basics</a></li> <li class="toctree-l1 has-children"><a class="reference internal" href="../../guides/index.html">Guides</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l2 has-children"><a class="reference internal" href="../../guides/flax_fundamentals/index.html">Flax 
fundamentals</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference external" href="https://jax.readthedocs.io/en/latest/jax-101/index.html">JAX 101</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/flax_fundamentals/flax_basics.html">Flax Basics</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/flax_fundamentals/state_params.html">Managing Parameters and State</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/flax_fundamentals/setup_or_nncompact.html"><code class="docutils literal notranslate"><span class="pre">setup</span></code> vs <code class="docutils literal notranslate"><span class="pre">compact</span></code></a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/flax_fundamentals/arguments.html">Dealing with Flax Module arguments</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/flax_fundamentals/rng_guide.html">Randomness and PRNGs in Flax</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../../guides/data_preprocessing/index.html">Data preprocessing</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../../guides/data_preprocessing/full_eval.html">Processing the entire Dataset</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/data_preprocessing/loading_datasets.html">Loading datasets</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../../guides/training_techniques/index.html">Training techniques</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid 
fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../../guides/training_techniques/batch_norm.html">Batch normalization</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/training_techniques/dropout.html">Dropout</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/training_techniques/lr_schedule.html">Learning rate scheduling</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/training_techniques/transfer_learning.html">Transfer learning</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/training_techniques/use_checkpointing.html">Save and load checkpoints</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../../guides/parallel_training/index.html">Parallel training</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../../guides/parallel_training/ensembling.html">Ensembling on multiple devices</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/parallel_training/flax_on_pjit.html">Scale up Flax Modules on multiple devices</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../../guides/model_inspection/index.html">Model inspection</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../../guides/model_inspection/model_surgery.html">Model surgery</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/model_inspection/extracting_intermediates.html">Extracting intermediate values</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference 
internal" href="../../guides/converting_and_upgrading/index.html">Converting and upgrading</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../../guides/converting_and_upgrading/haiku_migration_guide.html">Migrating from Haiku to Flax</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/converting_and_upgrading/convert_pytorch_to_flax.html">Convert PyTorch models to Flax</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/converting_and_upgrading/orbax_upgrade_guide.html">Migrate checkpointing to Orbax</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/converting_and_upgrading/optax_update_guide.html">Upgrading my codebase to Optax</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/converting_and_upgrading/linen_upgrade_guide.html">Upgrading my codebase to Linen</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/converting_and_upgrading/rnncell_upgrade_guide.html">RNNCellBase Upgrade Guide</a></li> <li class="toctree-l3"><a class="reference internal" href="../../guides/converting_and_upgrading/regular_dict_upgrade_guide.html">Migrate to regular dicts</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../../guides/quantization/index.html">Quantization</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../../guides/quantization/fp8_basics.html">User Guide on Using FP8</a></li> </ul> </details></li> <li class="toctree-l2"><a class="reference internal" href="../../guides/flax_sharp_bits.html">The Sharp Bits</a></li> </ul> </details></li> <li class="toctree-l1 has-children"><a class="reference 
internal" href="../../examples/index.html">Examples</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l2"><a class="reference internal" href="../../examples/core_examples.html">Core examples</a></li> <li class="toctree-l2"><a class="reference internal" href="../../examples/google_research_examples.html">Google Research examples</a></li> <li class="toctree-l2"><a class="reference internal" href="../../examples/repositories_that_use_flax.html">Repositories that use Flax</a></li> <li class="toctree-l2"><a class="reference internal" href="../../examples/community_examples.html">Community examples</a></li> </ul> </details></li> <li class="toctree-l1"><a class="reference internal" href="../../glossary.html">Glossary</a></li> <li class="toctree-l1"><a class="reference internal" href="../../faq.html">Frequently Asked Questions (FAQ)</a></li> <li class="toctree-l1 has-children"><a class="reference internal" href="../../developer_notes/index.html">Developer notes</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l2"><a class="reference internal" href="../../developer_notes/module_lifecycle.html">The Flax Module lifecycle</a></li> <li class="toctree-l2"><a class="reference internal" href="../../developer_notes/lift.html">Lifted transformations</a></li> <li class="toctree-l2"><a class="reference external" href="https://github.com/google/flax/tree/main/docs/flip">FLIPs</a></li> </ul> </details></li> <li class="toctree-l1"><a class="reference internal" href="../../philosophy.html">The Flax philosophy</a></li> <li class="toctree-l1"><a class="reference internal" href="../../contributing.html">How to contribute</a></li> <li class="toctree-l1 current active has-children"><a class="reference internal" href="../index.html">API Reference</a><details open="open"><summary><span 
class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul class="current"> <li class="toctree-l2"><a class="reference internal" href="../flax.config.html">flax.config package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.core.frozen_dict.html">flax.core.frozen_dict package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.cursor.html">flax.cursor package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.errors.html">flax.errors package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.jax_utils.html">flax.jax_utils package</a></li> <li class="toctree-l2 current active has-children"><a class="reference internal" href="index.html">flax.linen</a><details open="open"><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul class="current"> <li class="toctree-l3"><a class="reference internal" href="module.html">Module</a></li> <li class="toctree-l3"><a class="reference internal" href="init_apply.html">Init/Apply</a></li> <li class="toctree-l3"><a class="reference internal" href="layers.html">Layers</a></li> <li class="toctree-l3"><a class="reference internal" href="activation_functions.html">Activation functions</a></li> <li class="toctree-l3 current active"><a class="current reference internal" href="#">Initializers</a></li> <li class="toctree-l3"><a class="reference internal" href="transformations.html">Transformations</a></li> <li class="toctree-l3"><a class="reference internal" href="inspection.html">Inspection</a></li> <li class="toctree-l3"><a class="reference internal" href="variable.html">Variable dictionary</a></li> <li class="toctree-l3"><a class="reference internal" href="spmd.html">SPMD</a></li> <li class="toctree-l3"><a class="reference internal" href="decorators.html">Decorators</a></li> <li class="toctree-l3"><a class="reference internal" 
href="profiling.html">Profiling</a></li> </ul> </details></li> <li class="toctree-l2"><a class="reference internal" href="../flax.serialization.html">flax.serialization package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.struct.html">flax.struct package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.traceback_util.html">flax.traceback_util package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.training.html">flax.training package</a></li> <li class="toctree-l2"><a class="reference internal" href="../flax.traverse_util.html">flax.traverse_util package</a></li> </ul> </details></li> <li class="toctree-l1"><a class="reference external" href="https://flax.readthedocs.io/en/latest/index.html">Flax NNX</a></li> </ul> </div> </nav></div> </div> <div class="sidebar-primary-items__end sidebar-primary__section"> </div> <div id="rtd-footer-container"></div> </div> <main id="main-content" class="bd-main" role="main"> <div class="sbt-scroll-pixel-helper"></div> <div class="bd-content"> <div class="bd-article-container"> <div class="bd-header-article d-print-none"> <div class="header-article-items header-article__inner"> <div class="header-article-items__start"> <div class="header-article-item"><button class="sidebar-toggle primary-toggle btn btn-sm" title="Toggle primary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip"> <span class="fa-solid fa-bars"></span> </button></div> </div> <div class="header-article-items__end"> <div class="header-article-item"> <div class="article-header-buttons"> <a href="https://github.com/google/flax" target="_blank" class="btn btn-sm btn-source-repository-button" title="Source repository" data-bs-placement="bottom" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fab fa-github"></i> </span> </a> <div class="dropdown dropdown-download-buttons"> <button class="btn dropdown-toggle" type="button" data-bs-toggle="dropdown" 
aria-expanded="false" aria-label="Download this page"> <i class="fas fa-download"></i> </button> <ul class="dropdown-menu"> <li><a href="../../_sources/api_reference/flax.linen/initializers.rst" target="_blank" class="btn btn-sm btn-download-source-button dropdown-item" title="Download source file" data-bs-placement="left" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-file"></i> </span> <span class="btn__text-container">.rst</span> </a> </li> <li> <button onclick="window.print()" class="btn btn-sm btn-download-pdf-button dropdown-item" title="Print to PDF" data-bs-placement="left" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-file-pdf"></i> </span> <span class="btn__text-container">.pdf</span> </button> </li> </ul> </div> <button onclick="toggleFullScreen()" class="btn btn-sm btn-fullscreen-button" title="Fullscreen mode" data-bs-placement="bottom" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-expand"></i> </span> </button> <script> document.write(` <button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button" title="light/dark" aria-label="light/dark" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light"></i> <i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark"></i> <i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto"></i> </button> `); </script> <script> document.write(` <button class="btn btn-sm pst-navbar-icon search-button search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="fa-solid fa-magnifying-glass fa-lg"></i> </button> `); </script> <button class="sidebar-toggle secondary-toggle btn btn-sm" title="Toggle secondary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip"> <span class="fa-solid fa-list"></span> </button> </div></div> </div> </div> </div> <div 
id="jb-print-docs-body" class="onlyprint"> <h1>Initializers</h1> <!-- Table of contents --> <div id="print-main-content"> <div id="jb-print-toc"> <div> <h2> Contents </h2> </div> <nav aria-label="Page"> <ul class="visible nav section-nav flex-column"> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.constant"><code class="docutils literal notranslate"><span class="pre">constant()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.delta_orthogonal"><code class="docutils literal notranslate"><span class="pre">delta_orthogonal()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.glorot_normal"><code class="docutils literal notranslate"><span class="pre">glorot_normal()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.glorot_uniform"><code class="docutils literal notranslate"><span class="pre">glorot_uniform()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.he_normal"><code class="docutils literal notranslate"><span class="pre">he_normal()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.he_uniform"><code class="docutils literal notranslate"><span class="pre">he_uniform()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.kaiming_normal"><code class="docutils literal notranslate"><span class="pre">kaiming_normal()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.kaiming_uniform"><code class="docutils literal notranslate"><span 
class="pre">kaiming_uniform()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.lecun_normal"><code class="docutils literal notranslate"><span class="pre">lecun_normal()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.lecun_uniform"><code class="docutils literal notranslate"><span class="pre">lecun_uniform()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.normal"><code class="docutils literal notranslate"><span class="pre">normal()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.truncated_normal"><code class="docutils literal notranslate"><span class="pre">truncated_normal()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.ones"><code class="docutils literal notranslate"><span class="pre">ones()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.ones_init"><code class="docutils literal notranslate"><span class="pre">ones_init()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.orthogonal"><code class="docutils literal notranslate"><span class="pre">orthogonal()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.uniform"><code class="docutils literal notranslate"><span class="pre">uniform()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.variance_scaling"><code class="docutils literal notranslate"><span 
class="pre">variance_scaling()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.xavier_normal"><code class="docutils literal notranslate"><span class="pre">xavier_normal()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.xavier_uniform"><code class="docutils literal notranslate"><span class="pre">xavier_uniform()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.zeros"><code class="docutils literal notranslate"><span class="pre">zeros()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.zeros_init"><code class="docutils literal notranslate"><span class="pre">zeros_init()</span></code></a></li> </ul> </nav> </div> </div> </div> <div id="searchbox"></div> <article class="bd-article"> <div class="section" id="module-flax.linen.initializers"> <span id="initializers"></span><h1>Initializers<a class="headerlink" href="#module-flax.linen.initializers" title="Permalink to this heading">#</a></h1> <p>Initializers for Flax.</p> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.initializers.constant"> <span class="sig-prename descclassname"><span class="pre">flax.linen.initializers.</span></span><span class="sig-name descname"><span class="pre">constant</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">value</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=&lt;class</span> <span class="pre">'jax.numpy.float64'&gt;</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#flax.linen.initializers.constant" title="Permalink to this definition">#</a></dt> <dd><p>Builds an initializer that returns arrays full of a constant <code 
class="docutils literal notranslate"><span class="pre">value</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>value</strong> – the constant value with which to fill the initializer.</p></li> <li><p><strong>dtype</strong> – optional; the initializer’s default dtype.</p></li> </ul> </dd> </dl> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="o">-</span><span class="mi">7</span><span class="p">)</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">42</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="go">Array([[-7., -7., -7.],</span> <span class="go"> [-7., -7., -7.]], dtype=float32)</span> </pre></div> </div> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.initializers.delta_orthogonal"> <span class="sig-prename descclassname"><span class="pre">flax.linen.initializers.</span></span><span class="sig-name descname"><span 
class="pre">delta_orthogonal</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">scale=1.0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">column_axis=-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=&lt;class</span> <span class="pre">'jax.numpy.float64'&gt;</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#flax.linen.initializers.delta_orthogonal" title="Permalink to this definition">#</a></dt> <dd><p>Builds an initializer for delta orthogonal kernels.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>scale</strong> – the upper bound of the uniform distribution.</p></li> <li><p><strong>column_axis</strong> – the axis that contains the columns that should be orthogonal.</p></li> <li><p><strong>dtype</strong> – the default dtype of the weights.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>A <a class="reference external" href="https://arxiv.org/abs/1806.05393">delta orthogonal initializer</a>. 
The shape passed to the initializer must be 3D, 4D, or 5D.</p> </dd> </dl> <p>Examples:</p> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">delta_orthogonal</span><span class="p">()</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">42</span><span class="p">),</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="go">Array([[[ 0. , 0. , 0. ],</span> <span class="go"> [ 0. , 0. , 0. ],</span> <span class="go"> [ 0. , 0. , 0. ]],</span> <span class="go"> [[ 0.27858758, -0.7949833 , -0.53887904],</span> <span class="go"> [ 0.9120717 , 0.04322892, 0.40774566],</span> <span class="go"> [-0.30085585, -0.6050892 , 0.73712474]],</span> <span class="go"> [[ 0. , 0. , 0. ],</span> <span class="go"> [ 0. , 0. , 0. ],</span> <span class="go"> [ 0. , 0. , 0. 
]]], dtype=float32)</span> </pre></div> </div> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.initializers.glorot_normal"> <span class="sig-prename descclassname"><span class="pre">flax.linen.initializers.</span></span><span class="sig-name descname"><span class="pre">glorot_normal</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">in_axis=-2</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_axis=-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_axis=()</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=&lt;class</span> <span class="pre">'jax.numpy.float64'&gt;</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#flax.linen.initializers.glorot_normal" title="Permalink to this definition">#</a></dt> <dd><p>Builds a Glorot normal initializer (aka Xavier normal initializer).</p> <p>A <a class="reference external" href="http://proceedings.mlr.press/v9/glorot10a.html">Glorot normal initializer</a> is a specialization of <code class="xref py py-func docutils literal notranslate"><span class="pre">jax.nn.initializers.variance_scaling()</span></code> where <code class="docutils literal notranslate"><span class="pre">scale</span> <span class="pre">=</span> <span class="pre">1.0</span></code>, <code class="docutils literal notranslate"><span class="pre">mode=&quot;fan_avg&quot;</span></code>, and <code class="docutils literal notranslate"><span class="pre">distribution=&quot;truncated_normal&quot;</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>in_axis</strong> – axis or sequence of axes of the input dimension in the weights array.</p></li> <li><p><strong>out_axis</strong> – axis or sequence of axes of the output dimension in the weights array.</p></li> 
<li><p><strong>batch_axis</strong> – axis or sequence of axes in the weight array that should be ignored.</p></li> <li><p><strong>dtype</strong> – the dtype of the weights.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>An initializer.</p> </dd> </dl> <p>Examples:</p> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">glorot_normal</span><span class="p">()</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">42</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="go">Array([[ 0.41770416, 0.75262755, 0.7619329 ],</span> <span class="go"> [-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32)</span> </pre></div> </div> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.initializers.glorot_uniform"> <span class="sig-prename descclassname"><span class="pre">flax.linen.initializers.</span></span><span class="sig-name descname"><span class="pre">glorot_uniform</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span 
class="pre">in_axis=-2</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_axis=-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_axis=()</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=&lt;class</span> <span class="pre">'jax.numpy.float64'&gt;</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#flax.linen.initializers.glorot_uniform" title="Permalink to this definition">#</a></dt> <dd><p>Builds a Glorot uniform initializer (aka Xavier uniform initializer).</p> <p>A <a class="reference external" href="http://proceedings.mlr.press/v9/glorot10a.html">Glorot uniform initializer</a> is a specialization of <code class="xref py py-func docutils literal notranslate"><span class="pre">jax.nn.initializers.variance_scaling()</span></code> where <code class="docutils literal notranslate"><span class="pre">scale</span> <span class="pre">=</span> <span class="pre">1.0</span></code>, <code class="docutils literal notranslate"><span class="pre">mode=&quot;fan_avg&quot;</span></code>, and <code class="docutils literal notranslate"><span class="pre">distribution=&quot;uniform&quot;</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>in_axis</strong> – axis or sequence of axes of the input dimension in the weights array.</p></li> <li><p><strong>out_axis</strong> – axis or sequence of axes of the output dimension in the weights array.</p></li> <li><p><strong>batch_axis</strong> – axis or sequence of axes in the weight array that should be ignored.</p></li> <li><p><strong>dtype</strong> – the dtype of the weights.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>An initializer.</p> </dd> </dl> <p>Examples:</p> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; 
</span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">glorot_uniform</span><span class="p">()</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">42</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="go">Array([[ 0.50350785, 0.8088631 , 0.81566876],</span> <span class="go"> [-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32)</span> </pre></div> </div> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.initializers.he_normal"> <span class="sig-prename descclassname"><span class="pre">flax.linen.initializers.</span></span><span class="sig-name descname"><span class="pre">he_normal</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">in_axis=-2</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_axis=-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_axis=()</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=&lt;class</span> <span class="pre">'jax.numpy.float64'&gt;</span></span></em><span class="sig-paren">)</span><a class="headerlink" 
href="#flax.linen.initializers.he_normal" title="Permalink to this definition">#</a></dt> <dd><p>Builds a He normal initializer (aka Kaiming normal initializer).</p> <p>A <a class="reference external" href="https://arxiv.org/abs/1502.01852">He normal initializer</a> is a specialization of <code class="xref py py-func docutils literal notranslate"><span class="pre">jax.nn.initializers.variance_scaling()</span></code> where <code class="docutils literal notranslate"><span class="pre">scale</span> <span class="pre">=</span> <span class="pre">2.0</span></code>, <code class="docutils literal notranslate"><span class="pre">mode=&quot;fan_in&quot;</span></code>, and <code class="docutils literal notranslate"><span class="pre">distribution=&quot;truncated_normal&quot;</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>in_axis</strong> – axis or sequence of axes of the input dimension in the weights array.</p></li> <li><p><strong>out_axis</strong> – axis or sequence of axes of the output dimension in the weights array.</p></li> <li><p><strong>batch_axis</strong> – axis or sequence of axes in the weight array that should be ignored.</p></li> <li><p><strong>dtype</strong> – the dtype of the weights.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>An initializer.</p> </dd> </dl> <p>Examples:</p> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span 
class="n">he_normal</span><span class="p">()</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">42</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="go">Array([[ 0.6604483 , 1.1900088 , 1.2047218 ],</span> <span class="go"> [-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32)</span> </pre></div> </div> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.initializers.he_uniform"> <span class="sig-prename descclassname"><span class="pre">flax.linen.initializers.</span></span><span class="sig-name descname"><span class="pre">he_uniform</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">in_axis=-2</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_axis=-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_axis=()</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=&lt;class</span> <span class="pre">'jax.numpy.float64'&gt;</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#flax.linen.initializers.he_uniform" title="Permalink to this definition">#</a></dt> <dd><p>Builds a He uniform initializer (aka Kaiming uniform initializer).</p> <p>A <a class="reference external" href="https://arxiv.org/abs/1502.01852">He uniform initializer</a> is a specialization of <code class="xref py py-func docutils literal notranslate"><span class="pre">jax.nn.initializers.variance_scaling()</span></code> where <code class="docutils literal notranslate"><span 
class="pre">scale</span> <span class="pre">=</span> <span class="pre">2.0</span></code>, <code class="docutils literal notranslate"><span class="pre">mode=&quot;fan_in&quot;</span></code>, and <code class="docutils literal notranslate"><span class="pre">distribution=&quot;uniform&quot;</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>in_axis</strong> – axis or sequence of axes of the input dimension in the weights array.</p></li> <li><p><strong>out_axis</strong> – axis or sequence of axes of the output dimension in the weights array.</p></li> <li><p><strong>batch_axis</strong> – axis or sequence of axes in the weight array that should be ignored.</p></li> <li><p><strong>dtype</strong> – the dtype of the weights.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>An initializer.</p> </dd> </dl> <p>Examples:</p> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">he_uniform</span><span class="p">()</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">42</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span 
class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="go">Array([[ 0.79611576, 1.2789248 , 1.2896855 ],</span> <span class="go"> [-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32)</span> </pre></div> </div> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.initializers.kaiming_normal"> <span class="sig-prename descclassname"><span class="pre">flax.linen.initializers.</span></span><span class="sig-name descname"><span class="pre">kaiming_normal</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">in_axis=-2</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_axis=-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_axis=()</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=&lt;class</span> <span class="pre">'jax.numpy.float64'&gt;</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#flax.linen.initializers.kaiming_normal" title="Permalink to this definition">#</a></dt> <dd><p>Builds a He normal initializer (aka Kaiming normal initializer).</p> <p>A <a class="reference external" href="https://arxiv.org/abs/1502.01852">He normal initializer</a> is a specialization of <code class="xref py py-func docutils literal notranslate"><span class="pre">jax.nn.initializers.variance_scaling()</span></code> where <code class="docutils literal notranslate"><span class="pre">scale</span> <span class="pre">=</span> <span class="pre">2.0</span></code>, <code class="docutils literal notranslate"><span class="pre">mode=&quot;fan_in&quot;</span></code>, and <code class="docutils literal notranslate"><span class="pre">distribution=&quot;truncated_normal&quot;</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>in_axis</strong> – axis or 
sequence of axes of the input dimension in the weights array.</p></li> <li><p><strong>out_axis</strong> – axis or sequence of axes of the output dimension in the weights array.</p></li> <li><p><strong>batch_axis</strong> – axis or sequence of axes in the weight array that should be ignored.</p></li> <li><p><strong>dtype</strong> – the dtype of the weights.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>An initializer.</p> </dd> </dl> <p>Examples:</p> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">he_normal</span><span class="p">()</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">42</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="go">Array([[ 0.6604483 , 1.1900088 , 1.2047218 ],</span> <span class="go"> [-0.87225807, -0.95321447, 0.1369438 ]], dtype=float32)</span> </pre></div> </div> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.initializers.kaiming_uniform"> <span class="sig-prename descclassname"><span 
class="pre">flax.linen.initializers.</span></span><span class="sig-name descname"><span class="pre">kaiming_uniform</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">in_axis=-2</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_axis=-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_axis=()</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=&lt;class</span> <span class="pre">'jax.numpy.float64'&gt;</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#flax.linen.initializers.kaiming_uniform" title="Permalink to this definition">#</a></dt> <dd><p>Builds a He uniform initializer (aka Kaiming uniform initializer).</p> <p>A <a class="reference external" href="https://arxiv.org/abs/1502.01852">He uniform initializer</a> is a specialization of <code class="xref py py-func docutils literal notranslate"><span class="pre">jax.nn.initializers.variance_scaling()</span></code> where <code class="docutils literal notranslate"><span class="pre">scale</span> <span class="pre">=</span> <span class="pre">2.0</span></code>, <code class="docutils literal notranslate"><span class="pre">mode=&quot;fan_in&quot;</span></code>, and <code class="docutils literal notranslate"><span class="pre">distribution=&quot;uniform&quot;</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>in_axis</strong> – axis or sequence of axes of the input dimension in the weights array.</p></li> <li><p><strong>out_axis</strong> – axis or sequence of axes of the output dimension in the weights array.</p></li> <li><p><strong>batch_axis</strong> – axis or sequence of axes in the weight array that should be ignored.</p></li> <li><p><strong>dtype</strong> – the dtype of the weights.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd 
class="field-even"><p>An initializer.</p> </dd> </dl> <p>Examples:</p> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">he_uniform</span><span class="p">()</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">42</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="go">Array([[ 0.79611576, 1.2789248 , 1.2896855 ],</span> <span class="go"> [-1.0108745 , -1.0855657 , 0.17398663]], dtype=float32)</span> </pre></div> </div> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.initializers.lecun_normal"> <span class="sig-prename descclassname"><span class="pre">flax.linen.initializers.</span></span><span class="sig-name descname"><span class="pre">lecun_normal</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">in_axis=-2</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_axis=-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_axis=()</span></span></em>, <em class="sig-param"><span 
class="n"><span class="pre">dtype=&lt;class</span> <span class="pre">'jax.numpy.float64'&gt;</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#flax.linen.initializers.lecun_normal" title="Permalink to this definition">#</a></dt> <dd><p>Builds a LeCun normal initializer.</p> <p>A <a class="reference external" href="https://arxiv.org/abs/1706.02515">LeCun normal initializer</a> is a specialization of <code class="xref py py-func docutils literal notranslate"><span class="pre">jax.nn.initializers.variance_scaling()</span></code> where <code class="docutils literal notranslate"><span class="pre">scale</span> <span class="pre">=</span> <span class="pre">1.0</span></code>, <code class="docutils literal notranslate"><span class="pre">mode=&quot;fan_in&quot;</span></code>, and <code class="docutils literal notranslate"><span class="pre">distribution=&quot;truncated_normal&quot;</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>in_axis</strong> – axis or sequence of axes of the input dimension in the weights array.</p></li> <li><p><strong>out_axis</strong> – axis or sequence of axes of the output dimension in the weights array.</p></li> <li><p><strong>batch_axis</strong> – axis or sequence of axes in the weight array that should be ignored.</p></li> <li><p><strong>dtype</strong> – the dtype of the weights.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>An initializer.</p> </dd> </dl> <p>Examples:</p> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span> <span class="o">=</span> <span 
class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">lecun_normal</span><span class="p">()</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">42</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="go">Array([[ 0.46700746, 0.8414632 , 0.8518669 ],</span> <span class="go"> [-0.61677957, -0.67402434, 0.09683388]], dtype=float32)</span> </pre></div> </div> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.initializers.lecun_uniform"> <span class="sig-prename descclassname"><span class="pre">flax.linen.initializers.</span></span><span class="sig-name descname"><span class="pre">lecun_uniform</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">in_axis=-2</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_axis=-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_axis=()</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=&lt;class</span> <span class="pre">'jax.numpy.float64'&gt;</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#flax.linen.initializers.lecun_uniform" title="Permalink to this definition">#</a></dt> <dd><p>Builds a LeCun uniform initializer.</p> <p>A <a class="reference external" href="https://arxiv.org/abs/1706.02515">LeCun uniform initializer</a> is a specialization of <code class="xref py py-func docutils 
literal notranslate"><span class="pre">jax.nn.initializers.variance_scaling()</span></code> where <code class="docutils literal notranslate"><span class="pre">scale</span> <span class="pre">=</span> <span class="pre">1.0</span></code>, <code class="docutils literal notranslate"><span class="pre">mode=&quot;fan_in&quot;</span></code>, and <code class="docutils literal notranslate"><span class="pre">distribution=&quot;uniform&quot;</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>in_axis</strong> – axis or sequence of axes of the input dimension in the weights array.</p></li> <li><p><strong>out_axis</strong> – axis or sequence of axes of the output dimension in the weights array.</p></li> <li><p><strong>batch_axis</strong> – axis or sequence of axes in the weight array that should be ignored.</p></li> <li><p><strong>dtype</strong> – the dtype of the weights.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>An initializer.</p> </dd> </dl> <p>Examples:</p> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">lecun_uniform</span><span class="p">()</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">42</span><span 
class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="go">Array([[ 0.56293887, 0.90433645, 0.9119454 ],</span> <span class="go"> [-0.71479625, -0.7676109 , 0.12302713]], dtype=float32)</span> </pre></div> </div> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.initializers.normal"> <span class="sig-prename descclassname"><span class="pre">flax.linen.initializers.</span></span><span class="sig-name descname"><span class="pre">normal</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">stddev=0.01</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=&lt;class</span> <span class="pre">'jax.numpy.float64'&gt;</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#flax.linen.initializers.normal" title="Permalink to this definition">#</a></dt> <dd><p>Builds an initializer that returns real normally-distributed random arrays.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>stddev</strong> – optional; the standard deviation of the distribution.</p></li> <li><p><strong>dtype</strong> – optional; the initializer’s default dtype.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>An initializer that returns arrays whose values are normally distributed with mean <code class="docutils literal notranslate"><span class="pre">0</span></code> and standard deviation <code class="docutils literal notranslate"><span class="pre">stddev</span></code>.</p> </dd> </dl> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span 
class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="mf">5.0</span><span class="p">)</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">42</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="go">Array([[ 3.0613258 , 5.6129413 , 5.6866574 ],</span> <span class="go"> [-4.063663 , -4.4520254 , 0.63115686]], dtype=float32)</span> </pre></div> </div> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.initializers.truncated_normal"> <span class="sig-prename descclassname"><span class="pre">flax.linen.initializers.</span></span><span class="sig-name descname"><span class="pre">truncated_normal</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">stddev=0.01</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=&lt;class</span> <span class="pre">'jax.numpy.float64'&gt;</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">lower=-2.0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">upper=2.0</span></span></em><span class="sig-paren">)</span><a class="headerlink" 
href="#flax.linen.initializers.truncated_normal" title="Permalink to this definition">#</a></dt> <dd><p>Builds an initializer that returns truncated-normal random arrays.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>stddev</strong> – optional; the standard deviation of the untruncated distribution. Note that this function does not apply the stddev correction as is done in the variance scaling initializers, and users are expected to apply this correction themselves via the stddev arg if they wish to employ it.</p></li> <li><p><strong>dtype</strong> – optional; the initializer’s default dtype.</p></li> <li><p><strong>lower</strong> – Float representing the lower bound for truncation. Applied before the output is multiplied by the stddev.</p></li> <li><p><strong>upper</strong> – Float representing the upper bound for truncation. Applied before the output is multiplied by the stddev.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>An initializer that returns arrays whose values follow the truncated normal distribution with mean <code class="docutils literal notranslate"><span class="pre">0</span></code> and standard deviation <code class="docutils literal notranslate"><span class="pre">stddev</span></code>, and range <span class="math notranslate nohighlight">\(\rm{lower * stddev} &lt; x &lt; \rm{upper * stddev}\)</span>.</p> </dd> </dl> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span 
class="n">initializers</span><span class="o">.</span><span class="n">truncated_normal</span><span class="p">(</span><span class="mf">5.0</span><span class="p">)</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">42</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="go">Array([[ 2.9047365, 5.2338114, 5.29852 ],</span> <span class="go"> [-3.836303 , -4.192359 , 0.6022964]], dtype=float32)</span> </pre></div> </div> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.initializers.ones"> <span class="sig-prename descclassname"><span class="pre">flax.linen.initializers.</span></span><span class="sig-name descname"><span class="pre">ones</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">key</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">shape</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=&lt;class</span> <span class="pre">'jax.numpy.float64'&gt;</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#flax.linen.initializers.ones" title="Permalink to this definition">#</a></dt> <dd><p>An initializer that returns a constant array full of ones.</p> <p>The <code class="docutils literal notranslate"><span class="pre">key</span></code> argument is ignored.</p> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span 
class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">42</span><span class="p">),</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="go">Array([[1., 1.],</span> <span class="go"> [1., 1.],</span> <span class="go"> [1., 1.]], dtype=float32)</span> </pre></div> </div> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.initializers.ones_init"> <span class="sig-prename descclassname"><span class="pre">flax.linen.initializers.</span></span><span class="sig-name descname"><span class="pre">ones_init</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/initializers.html#ones_init"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.initializers.ones_init" title="Permalink to this definition">#</a></dt> <dd><p>Builds an initializer that returns a constant array full of ones.</p> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span 
class="nn">flax.linen.initializers</span> <span class="kn">import</span> <span class="n">ones_init</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">ones_initializer</span> <span class="o">=</span> <span class="n">ones_init</span><span class="p">()</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">ones_initializer</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">42</span><span class="p">),</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="go">Array([[1., 1.],</span> <span class="go"> [1., 1.],</span> <span class="go"> [1., 1.]], dtype=float32)</span> </pre></div> </div> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.initializers.orthogonal"> <span class="sig-prename descclassname"><span class="pre">flax.linen.initializers.</span></span><span class="sig-name descname"><span class="pre">orthogonal</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">scale=1.0</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">column_axis=-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=&lt;class</span> <span class="pre">'jax.numpy.float64'&gt;</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#flax.linen.initializers.orthogonal" title="Permalink to this definition">#</a></dt> <dd><p>Builds an initializer that returns uniformly distributed orthogonal matrices.</p> <p>If the shape is not square, the matrices will have orthonormal rows or columns depending on which side is smaller.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> 
<dd class="field-odd"><ul class="simple"> <li><p><strong>scale</strong> – the multiplicative factor applied to the sampled orthogonal matrix.</p></li> <li><p><strong>column_axis</strong> – the axis that contains the columns that should be orthogonal.</p></li> <li><p><strong>dtype</strong> – the default dtype of the weights.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>An orthogonal initializer.</p> </dd> </dl> <p>Examples:</p> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">orthogonal</span><span class="p">()</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">42</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="go">Array([[ 3.9026976e-01, 7.2495741e-01, -5.6756169e-01],</span> <span class="go"> [ 8.8047469e-01, -4.7409311e-01, -1.3157725e-04]], dtype=float32)</span> </pre></div> </div> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.initializers.uniform"> <span class="sig-prename descclassname"><span class="pre">flax.linen.initializers.</span></span><span 
class="sig-name descname"><span class="pre">uniform</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">scale=0.01</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=&lt;class</span> <span class="pre">'jax.numpy.float64'&gt;</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#flax.linen.initializers.uniform" title="Permalink to this definition">#</a></dt> <dd><p>Builds an initializer that returns real uniformly-distributed random arrays.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>scale</strong> – optional; the upper bound of the random distribution.</p></li> <li><p><strong>dtype</strong> – optional; the initializer’s default dtype.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>An initializer that returns arrays whose values are uniformly distributed in the range <code class="docutils literal notranslate"><span class="pre">[0,</span> <span class="pre">scale)</span></code>.</p> </dd> </dl> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="mf">10.0</span><span class="p">)</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span 
class="n">key</span><span class="p">(</span><span class="mi">42</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="go">Array([[7.298188 , 8.691938 , 8.7230015],</span> <span class="go"> [2.0818567, 1.8662417, 5.5022564]], dtype=float32)</span> </pre></div> </div> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.initializers.variance_scaling"> <span class="sig-prename descclassname"><span class="pre">flax.linen.initializers.</span></span><span class="sig-name descname"><span class="pre">variance_scaling</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">scale</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">mode</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">distribution</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">in_axis=-2</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_axis=-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_axis=()</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=&lt;class</span> <span class="pre">'jax.numpy.float64'&gt;</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#flax.linen.initializers.variance_scaling" title="Permalink to this definition">#</a></dt> <dd><p>Initializer that adapts its scale to the shape of the weights tensor.</p> <p>With <code class="docutils literal notranslate"><span class="pre">distribution=&quot;truncated_normal&quot;</span></code> or <code class="docutils literal notranslate"><span class="pre">distribution=&quot;normal&quot;</span></code>, samples are drawn from a (truncated) normal distribution with a 
mean of zero and a standard deviation (after truncation, if applicable) of <span class="math notranslate nohighlight">\(\sqrt{\frac{scale}{n}}\)</span>, where <cite>n</cite> is:</p> <ul class="simple"> <li><p>the number of input units in the weights tensor, if <code class="docutils literal notranslate"><span class="pre">mode=&quot;fan_in&quot;</span></code>,</p></li> <li><p>the number of output units, if <code class="docutils literal notranslate"><span class="pre">mode=&quot;fan_out&quot;</span></code>, or</p></li> <li><p>the average of the numbers of input and output units, if <code class="docutils literal notranslate"><span class="pre">mode=&quot;fan_avg&quot;</span></code>.</p></li> </ul> <p>This initializer can be configured with <code class="docutils literal notranslate"><span class="pre">in_axis</span></code>, <code class="docutils literal notranslate"><span class="pre">out_axis</span></code>, and <code class="docutils literal notranslate"><span class="pre">batch_axis</span></code> to work with general convolutional or dense layers; axes that are not in any of those arguments are assumed to be the “receptive field” (convolution kernel spatial axes).</p> <p>With <code class="docutils literal notranslate"><span class="pre">distribution=&quot;truncated_normal&quot;</span></code>, the absolute values of the samples are truncated at 2 standard deviations before scaling.</p> <p>With <code class="docutils literal notranslate"><span class="pre">distribution=&quot;uniform&quot;</span></code>, samples are drawn from:</p> <ul class="simple"> <li><p>a uniform interval, if <cite>dtype</cite> is real, or</p></li> <li><p>a uniform disk, if <cite>dtype</cite> is complex,</p></li> </ul> <p>with a mean of zero and a standard deviation of <span class="math notranslate nohighlight">\(\sqrt{\frac{scale}{n}}\)</span> where <cite>n</cite> is defined above.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> 
<li><p><strong>scale</strong> – scaling factor (positive float).</p></li> <li><p><strong>mode</strong> – one of <code class="docutils literal notranslate"><span class="pre">&quot;fan_in&quot;</span></code>, <code class="docutils literal notranslate"><span class="pre">&quot;fan_out&quot;</span></code>, and <code class="docutils literal notranslate"><span class="pre">&quot;fan_avg&quot;</span></code>.</p></li> <li><p><strong>distribution</strong> – random distribution to use. One of <code class="docutils literal notranslate"><span class="pre">&quot;truncated_normal&quot;</span></code>, <code class="docutils literal notranslate"><span class="pre">&quot;normal&quot;</span></code> and <code class="docutils literal notranslate"><span class="pre">&quot;uniform&quot;</span></code>.</p></li> <li><p><strong>in_axis</strong> – axis or sequence of axes of the input dimension in the weights array.</p></li> <li><p><strong>out_axis</strong> – axis or sequence of axes of the output dimension in the weights array.</p></li> <li><p><strong>batch_axis</strong> – axis or sequence of axes in the weight array that should be ignored.</p></li> <li><p><strong>dtype</strong> – the dtype of the weights.</p></li> </ul> </dd> </dl> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.initializers.xavier_normal"> <span class="sig-prename descclassname"><span class="pre">flax.linen.initializers.</span></span><span class="sig-name descname"><span class="pre">xavier_normal</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">in_axis=-2</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_axis=-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_axis=()</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=&lt;class</span> <span class="pre">'jax.numpy.float64'&gt;</span></span></em><span class="sig-paren">)</span><a class="headerlink" 
href="#flax.linen.initializers.xavier_normal" title="Permalink to this definition">#</a></dt> <dd><p>Builds a Glorot normal initializer (aka Xavier normal initializer).</p> <p>A <a class="reference external" href="http://proceedings.mlr.press/v9/glorot10a.html">Glorot normal initializer</a> is a specialization of <code class="xref py py-func docutils literal notranslate"><span class="pre">jax.nn.initializers.variance_scaling()</span></code> where <code class="docutils literal notranslate"><span class="pre">scale</span> <span class="pre">=</span> <span class="pre">1.0</span></code>, <code class="docutils literal notranslate"><span class="pre">mode=&quot;fan_avg&quot;</span></code>, and <code class="docutils literal notranslate"><span class="pre">distribution=&quot;truncated_normal&quot;</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>in_axis</strong> – axis or sequence of axes of the input dimension in the weights array.</p></li> <li><p><strong>out_axis</strong> – axis or sequence of axes of the output dimension in the weights array.</p></li> <li><p><strong>batch_axis</strong> – axis or sequence of axes in the weight array that should be ignored.</p></li> <li><p><strong>dtype</strong> – the dtype of the weights.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>An initializer.</p> </dd> </dl> <p>Examples:</p> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span 
class="o">.</span><span class="n">glorot_normal</span><span class="p">()</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">42</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="go">Array([[ 0.41770416, 0.75262755, 0.7619329 ],</span> <span class="go"> [-0.5516644 , -0.6028657 , 0.08661086]], dtype=float32)</span> </pre></div> </div> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.initializers.xavier_uniform"> <span class="sig-prename descclassname"><span class="pre">flax.linen.initializers.</span></span><span class="sig-name descname"><span class="pre">xavier_uniform</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">in_axis=-2</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">out_axis=-1</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">batch_axis=()</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=&lt;class</span> <span class="pre">'jax.numpy.float64'&gt;</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#flax.linen.initializers.xavier_uniform" title="Permalink to this definition">#</a></dt> <dd><p>Builds a Glorot uniform initializer (aka Xavier uniform initializer).</p> <p>A <a class="reference external" href="http://proceedings.mlr.press/v9/glorot10a.html">Glorot uniform initializer</a> is a specialization of <code class="xref py py-func docutils literal notranslate"><span 
class="pre">jax.nn.initializers.variance_scaling()</span></code> where <code class="docutils literal notranslate"><span class="pre">scale</span> <span class="pre">=</span> <span class="pre">1.0</span></code>, <code class="docutils literal notranslate"><span class="pre">mode=&quot;fan_avg&quot;</span></code>, and <code class="docutils literal notranslate"><span class="pre">distribution=&quot;uniform&quot;</span></code>.</p> <dl class="field-list simple"> <dt class="field-odd">Parameters</dt> <dd class="field-odd"><ul class="simple"> <li><p><strong>in_axis</strong> – axis or sequence of axes of the input dimension in the weights array.</p></li> <li><p><strong>out_axis</strong> – axis or sequence of axes of the output dimension in the weights array.</p></li> <li><p><strong>batch_axis</strong> – axis or sequence of axes in the weight array that should be ignored.</p></li> <li><p><strong>dtype</strong> – the dtype of the weights.</p></li> </ul> </dd> <dt class="field-even">Returns</dt> <dd class="field-even"><p>An initializer.</p> </dd> </dl> <p>Examples:</p> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span> <span class="o">=</span> <span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">glorot_uniform</span><span class="p">()</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">initializer</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">42</span><span class="p">),</span> <span 
class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="go">Array([[ 0.50350785, 0.8088631 , 0.81566876],</span> <span class="go"> [-0.6393332 , -0.6865721 , 0.11003882]], dtype=float32)</span> </pre></div> </div> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.initializers.zeros"> <span class="sig-prename descclassname"><span class="pre">flax.linen.initializers.</span></span><span class="sig-name descname"><span class="pre">zeros</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">key</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">shape</span></span></em>, <em class="sig-param"><span class="n"><span class="pre">dtype=&lt;class</span> <span class="pre">'jax.numpy.float64'&gt;</span></span></em><span class="sig-paren">)</span><a class="headerlink" href="#flax.linen.initializers.zeros" title="Permalink to this definition">#</a></dt> <dd><p>An initializer that returns a constant array full of zeros.</p> <p>The <code class="docutils literal notranslate"><span class="pre">key</span></code> argument is ignored.</p> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">jax</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span 
class="p">(</span><span class="mi">42</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="go">Array([[0., 0., 0.],</span> <span class="go"> [0., 0., 0.]], dtype=float32)</span> </pre></div> </div> </dd></dl> <dl class="py function"> <dt class="sig sig-object py" id="flax.linen.initializers.zeros_init"> <span class="sig-prename descclassname"><span class="pre">flax.linen.initializers.</span></span><span class="sig-name descname"><span class="pre">zeros_init</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="../../_modules/flax/linen/initializers.html#zeros_init"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#flax.linen.initializers.zeros_init" title="Permalink to this definition">#</a></dt> <dd><p>Builds an initializer that returns a constant array full of zeros.</p> <div class="doctest highlight-default notranslate"><div class="highlight"><pre><span></span><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">jax</span><span class="o">,</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="nn">jnp</span> <span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">flax.linen.initializers</span> <span class="kn">import</span> <span class="n">zeros_init</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">zeros_initializer</span> <span class="o">=</span> <span class="n">zeros_init</span><span class="p">()</span> <span class="gp">&gt;&gt;&gt; </span><span class="n">zeros_initializer</span><span class="p">(</span><span class="n">jax</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">key</span><span 
class="p">(</span><span class="mi">42</span><span class="p">),</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">jnp</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="go">Array([[0., 0., 0.],</span> <span class="go"> [0., 0., 0.]], dtype=float32)</span> </pre></div> </div> </dd></dl> </div> </article> <footer class="prev-next-footer d-print-none"> <div class="prev-next-area"> <a class="left-prev" href="activation_functions.html" title="previous page"> <i class="fa-solid fa-angle-left"></i> <div class="prev-next-info"> <p class="prev-next-subtitle">previous</p> <p class="prev-next-title">Activation functions</p> </div> </a> <a class="right-next" href="transformations.html" title="next page"> <div class="prev-next-info"> <p class="prev-next-subtitle">next</p> <p class="prev-next-title">Transformations</p> </div> <i class="fa-solid fa-angle-right"></i> </a> </div> </footer> </div> <div class="bd-sidebar-secondary bd-toc"><div class="sidebar-secondary-items sidebar-secondary__inner"> <div class="sidebar-secondary-item"> <div class="page-toc tocsection onthispage"> <i class="fa-solid fa-list"></i> Contents </div> <nav class="bd-toc-nav page-toc"> <ul class="visible nav section-nav flex-column"> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.constant"><code class="docutils literal notranslate"><span class="pre">constant()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.delta_orthogonal"><code class="docutils literal notranslate"><span class="pre">delta_orthogonal()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.glorot_normal"><code class="docutils literal notranslate"><span 
class="pre">glorot_normal()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.glorot_uniform"><code class="docutils literal notranslate"><span class="pre">glorot_uniform()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.he_normal"><code class="docutils literal notranslate"><span class="pre">he_normal()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.he_uniform"><code class="docutils literal notranslate"><span class="pre">he_uniform()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.kaiming_normal"><code class="docutils literal notranslate"><span class="pre">kaiming_normal()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.kaiming_uniform"><code class="docutils literal notranslate"><span class="pre">kaiming_uniform()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.lecun_normal"><code class="docutils literal notranslate"><span class="pre">lecun_normal()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.lecun_uniform"><code class="docutils literal notranslate"><span class="pre">lecun_uniform()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.normal"><code class="docutils literal notranslate"><span class="pre">normal()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.truncated_normal"><code class="docutils literal notranslate"><span 
class="pre">truncated_normal()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.ones"><code class="docutils literal notranslate"><span class="pre">ones()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.ones_init"><code class="docutils literal notranslate"><span class="pre">ones_init()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.orthogonal"><code class="docutils literal notranslate"><span class="pre">orthogonal()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.uniform"><code class="docutils literal notranslate"><span class="pre">uniform()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.variance_scaling"><code class="docutils literal notranslate"><span class="pre">variance_scaling()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.xavier_normal"><code class="docutils literal notranslate"><span class="pre">xavier_normal()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.xavier_uniform"><code class="docutils literal notranslate"><span class="pre">xavier_uniform()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.zeros"><code class="docutils literal notranslate"><span class="pre">zeros()</span></code></a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#flax.linen.initializers.zeros_init"><code class="docutils literal notranslate"><span 
class="pre">zeros_init()</span></code></a></li> </ul> </nav></div> </div></div> </div> <footer class="bd-footer-content"> <div class="bd-footer-content__inner container"> <div class="footer-item"> <p class="component-author"> By The Flax authors </p> </div> <div class="footer-item"> <p class="copyright"> © Copyright 2023, The Flax authors. <br/> </p> </div> <div class="footer-item"> </div> <div class="footer-item"> </div> </div> </footer> </main> </div> </div> <!-- Scripts loaded after <body> so the DOM is not blocked --> <script src="../../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b"></script> <script src="../../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b"></script> <footer class="bd-footer"> </footer> </body> </html>

Pages: 1 2 3 4 5 6 7 8 9 10