CINXE.COM

Convert PyTorch models to Flax

<!DOCTYPE html> <html lang="en" data-content_root="" > <head> <meta charset="utf-8" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" /> <title>Convert PyTorch models to Flax</title> <script data-cfasync="false"> document.documentElement.dataset.mode = localStorage.getItem("mode") || ""; document.documentElement.dataset.theme = localStorage.getItem("theme") || ""; </script> <!-- Loaded before other Sphinx assets --> <link href="../../_static/styles/theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link href="../../_static/styles/bootstrap.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link href="../../_static/styles/pydata-sphinx-theme.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link href="../../_static/vendor/fontawesome/6.5.2/css/all.min.css?digest=dfe6caa3a7d634c4db9b" rel="stylesheet" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="../../_static/vendor/fontawesome/6.5.2/webfonts/fa-solid-900.woff2" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="../../_static/vendor/fontawesome/6.5.2/webfonts/fa-brands-400.woff2" /> <link rel="preload" as="font" type="font/woff2" crossorigin href="../../_static/vendor/fontawesome/6.5.2/webfonts/fa-regular-400.woff2" /> <link rel="stylesheet" type="text/css" href="../../_static/pygments.css" /> <link rel="stylesheet" type="text/css" href="../../_static/styles/sphinx-book-theme.css" /> <link rel="stylesheet" type="text/css" href="../../_static/mystnb.4510f1fc1dee50b3e5859aac5469c37c29e427902b24a333a5f9fcb2f0b3ac41.css" /> <link rel="stylesheet" type="text/css" href="../../_static/sphinx-design.5ea377869091fd0449014c60fc090103.min.css" /> <link rel="stylesheet" type="text/css" href="../../_static/css/flax_theme.css" /> <!-- Pre-loaded scripts that we'll load fully later --> <link rel="preload" as="script" href="../../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b" /> <link rel="preload" as="script" href="../../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b" /> <script src="../../_static/vendor/fontawesome/6.5.2/js/all.min.js?digest=dfe6caa3a7d634c4db9b"></script> <script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script> <script src="../../_static/jquery.js"></script> <script src="../../_static/underscore.js"></script> <script src="../../_static/_sphinx_javascript_frameworks_compat.js"></script> <script src="../../_static/doctools.js"></script> <script src="../../_static/sphinx_highlight.js"></script> <script src="../../_static/scripts/sphinx-book-theme.js"></script> <script src="../../_static/design-tabs.js"></script> <script>DOCUMENTATION_OPTIONS.pagename = 'guides/converting_and_upgrading/convert_pytorch_to_flax';</script> <link rel="shortcut icon" href="../../_static/flax.png"/> <link rel="index" title="Index" href="../../genindex.html" /> <link rel="search" title="Search" href="../../search.html" /> <link rel="next" title="Migrate checkpointing to Orbax" href="orbax_upgrade_guide.html" /> <link rel="prev" title="Migrating from Haiku to Flax" href="haiku_migration_guide.html" /> <meta name="viewport" content="width=device-width, initial-scale=1"/> <meta name="docsearch:language" content="en"/> <script async type="text/javascript" src="/_/static/javascript/readthedocs-addons.js"></script><meta name="readthedocs-project-slug" content="flax-linen" /><meta name="readthedocs-version-slug" content="latest" /><meta name="readthedocs-resolver-filename" content="/guides/converting_and_upgrading/convert_pytorch_to_flax.html" /><meta name="readthedocs-http-status" content="200" /></head> <body data-bs-spy="scroll" data-bs-target=".bd-toc-nav" data-offset="180" data-bs-root-margin="0px 0px -60%" data-default-mode=""> <div id="pst-skip-link" class="skip-link d-print-none"><a href="#main-content">Skip to main content</a></div> <div id="pst-scroll-pixel-helper"></div> <button type="button" class="btn rounded-pill" id="pst-back-to-top"> <i class="fa-solid fa-arrow-up"></i>Back to top</button> <input type="checkbox" class="sidebar-toggle" id="pst-primary-sidebar-checkbox"/> <label class="overlay overlay-primary" for="pst-primary-sidebar-checkbox"></label> <input type="checkbox" class="sidebar-toggle" id="pst-secondary-sidebar-checkbox"/> <label class="overlay overlay-secondary" for="pst-secondary-sidebar-checkbox"></label> <div class="search-button__wrapper"> <div class="search-button__overlay"></div> <div class="search-button__search-container"> <form class="bd-search d-flex align-items-center" action="../../search.html" method="get"> <i class="fa-solid fa-magnifying-glass"></i> <input type="search" class="form-control" name="q" id="search-input" placeholder="Search..." aria-label="Search..." autocomplete="off" autocorrect="off" autocapitalize="off" spellcheck="false"/> <span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd>K</kbd></span> </form></div> </div> <div class="pst-async-banner-revealer d-none"> <aside id="bd-header-version-warning" class="d-none d-print-none" aria-label="Version warning"></aside> </div> <aside class="bd-header-announcement" aria-label="Announcement"> <div class="bd-header-announcement__content"> <a href="https://flax.readthedocs.io/en/latest/index.html" style="text-decoration: none; color: white;" > This site covers the old Flax Linen API. <span style="color: lightgray;">[Explore the new <b>Flax NNX</b> API ✨]</span> </a> </div> </aside> <header class="bd-header navbar navbar-expand-lg bd-navbar d-print-none"> </header> <div class="bd-container"> <div class="bd-container__inner bd-page-width"> <div class="bd-sidebar-primary bd-sidebar"> <div class="sidebar-header-items sidebar-primary__section"> </div> <div class="sidebar-primary-items__start sidebar-primary__section"> <div class="sidebar-primary-item"> <a class="navbar-brand logo" href="../../index.html"> <img src="../../_static/flax.png" class="logo__image only-light" alt=" - Home"/> <script>document.write(`<img src="../../_static/flax.png" class="logo__image only-dark" alt=" - Home"/>`);</script> </a></div> <div class="sidebar-primary-item"> <script> document.write(` <button class="btn search-button-field search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="fa-solid fa-magnifying-glass"></i> <span class="search-button__default-text">Search</span> <span class="search-button__kbd-shortcut"><kbd class="kbd-shortcut__modifier">Ctrl</kbd>+<kbd class="kbd-shortcut__modifier">K</kbd></span> </button> `); </script></div> <div class="sidebar-primary-item"><nav class="bd-links bd-docs-nav" aria-label="Main"> <div class="bd-toc-item navbar-nav active"> <ul class="current nav bd-sidenav"> <li class="toctree-l1"><a class="reference internal" href="../../quick_start.html">Quick start</a></li> <li class="toctree-l1"><a class="reference internal" href="../flax_fundamentals/flax_basics.html">Flax Basics</a></li> <li class="toctree-l1 current active has-children"><a class="reference internal" href="../index.html">Guides</a><details open="open"><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul class="current"> <li class="toctree-l2 has-children"><a class="reference internal" href="../flax_fundamentals/index.html">Flax fundamentals</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference external" href="https://jax.readthedocs.io/en/latest/jax-101/index.html">JAX 101</a></li> <li class="toctree-l3"><a class="reference internal" href="../flax_fundamentals/flax_basics.html">Flax Basics</a></li> <li class="toctree-l3"><a class="reference internal" href="../flax_fundamentals/state_params.html">Managing Parameters and State</a></li> <li class="toctree-l3"><a class="reference internal" href="../flax_fundamentals/setup_or_nncompact.html"><code class="docutils literal notranslate"><span class="pre">setup</span></code> vs <code class="docutils literal notranslate"><span class="pre">compact</span></code></a></li> <li class="toctree-l3"><a class="reference internal" href="../flax_fundamentals/arguments.html">Dealing with Flax Module arguments</a></li> <li class="toctree-l3"><a class="reference internal" href="../flax_fundamentals/rng_guide.html">Randomness and PRNGs in Flax</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../data_preprocessing/index.html">Data preprocessing</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../data_preprocessing/full_eval.html">Processing the entire Dataset</a></li> <li class="toctree-l3"><a class="reference internal" href="../data_preprocessing/loading_datasets.html">Loading datasets</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../training_techniques/index.html">Training techniques</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../training_techniques/batch_norm.html">Batch normalization</a></li> <li class="toctree-l3"><a class="reference internal" href="../training_techniques/dropout.html">Dropout</a></li> <li class="toctree-l3"><a class="reference internal" href="../training_techniques/lr_schedule.html">Learning rate scheduling</a></li> <li class="toctree-l3"><a class="reference internal" href="../training_techniques/transfer_learning.html">Transfer learning</a></li> <li class="toctree-l3"><a class="reference internal" href="../training_techniques/use_checkpointing.html">Save and load checkpoints</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../parallel_training/index.html">Parallel training</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../parallel_training/ensembling.html">Ensembling on multiple devices</a></li> <li class="toctree-l3"><a class="reference internal" href="../parallel_training/flax_on_pjit.html">Scale up Flax Modules on multiple devices</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../model_inspection/index.html">Model inspection</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../model_inspection/model_surgery.html">Model surgery</a></li> <li class="toctree-l3"><a class="reference internal" href="../model_inspection/extracting_intermediates.html">Extracting intermediate values</a></li> </ul> </details></li> <li class="toctree-l2 current active has-children"><a class="reference internal" href="index.html">Converting and upgrading</a><details open="open"><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul class="current"> <li class="toctree-l3"><a class="reference internal" href="haiku_migration_guide.html">Migrating from Haiku to Flax</a></li> <li class="toctree-l3 current active"><a class="current reference internal" href="#">Convert PyTorch models to Flax</a></li> <li class="toctree-l3"><a class="reference internal" href="orbax_upgrade_guide.html">Migrate checkpointing to Orbax</a></li> <li class="toctree-l3"><a class="reference internal" href="optax_update_guide.html">Upgrading my codebase to Optax</a></li> <li class="toctree-l3"><a class="reference internal" href="linen_upgrade_guide.html">Upgrading my codebase to Linen</a></li> <li class="toctree-l3"><a class="reference internal" href="rnncell_upgrade_guide.html">RNNCellBase Upgrade Guide</a></li> <li class="toctree-l3"><a class="reference internal" href="regular_dict_upgrade_guide.html">Migrate to regular dicts</a></li> </ul> </details></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../quantization/index.html">Quantization</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../quantization/fp8_basics.html">User Guide on Using FP8</a></li> </ul> </details></li> <li class="toctree-l2"><a class="reference internal" href="../flax_sharp_bits.html">The Sharp Bits</a></li> </ul> </details></li> <li class="toctree-l1 has-children"><a class="reference internal" href="../../examples/index.html">Examples</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l2"><a class="reference internal" href="../../examples/core_examples.html">Core examples</a></li> <li class="toctree-l2"><a class="reference internal" href="../../examples/google_research_examples.html">Google Research examples</a></li> <li class="toctree-l2"><a class="reference internal" href="../../examples/repositories_that_use_flax.html">Repositories that use Flax</a></li> <li class="toctree-l2"><a class="reference internal" href="../../examples/community_examples.html">Community examples</a></li> </ul> </details></li> <li class="toctree-l1"><a class="reference internal" href="../../glossary.html">Glossary</a></li> <li class="toctree-l1"><a class="reference internal" href="../../faq.html">Frequently Asked Questions (FAQ)</a></li> <li class="toctree-l1 has-children"><a class="reference internal" href="../../developer_notes/index.html">Developer notes</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l2"><a class="reference internal" href="../../developer_notes/module_lifecycle.html">The Flax Module lifecycle</a></li> <li class="toctree-l2"><a class="reference internal" href="../../developer_notes/lift.html">Lifted transformations</a></li> <li class="toctree-l2"><a class="reference external" href="https://github.com/google/flax/tree/main/docs/flip">FLIPs</a></li> </ul> </details></li> <li class="toctree-l1"><a class="reference internal" href="../../philosophy.html">The Flax philosophy</a></li> <li class="toctree-l1"><a class="reference internal" href="../../contributing.html">How to contribute</a></li> <li class="toctree-l1 has-children"><a class="reference internal" href="../../api_reference/index.html">API Reference</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.config.html">flax.config package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.core.frozen_dict.html">flax.core.frozen_dict package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.cursor.html">flax.cursor package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.errors.html">flax.errors package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.jax_utils.html">flax.jax_utils package</a></li> <li class="toctree-l2 has-children"><a class="reference internal" href="../../api_reference/flax.linen/index.html">flax.linen</a><details><summary><span class="toctree-toggle" role="presentation"><i class="fa-solid fa-chevron-down"></i></span></summary><ul> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/module.html">Module</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/init_apply.html">Init/Apply</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/layers.html">Layers</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/activation_functions.html">Activation functions</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/initializers.html">Initializers</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/transformations.html">Transformations</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/inspection.html">Inspection</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/variable.html">Variable dictionary</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/spmd.html">SPMD</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/decorators.html">Decorators</a></li> <li class="toctree-l3"><a class="reference internal" href="../../api_reference/flax.linen/profiling.html">Profiling</a></li> </ul> </details></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.serialization.html">flax.serialization package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.struct.html">flax.struct package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.traceback_util.html">flax.traceback_util package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.training.html">flax.training package</a></li> <li class="toctree-l2"><a class="reference internal" href="../../api_reference/flax.traverse_util.html">flax.traverse_util package</a></li> </ul> </details></li> <li class="toctree-l1"><a class="reference external" href="https://flax.readthedocs.io/en/latest/index.html">Flax NNX</a></li> </ul> </div> </nav></div> </div> <div class="sidebar-primary-items__end sidebar-primary__section"> </div> <div id="rtd-footer-container"></div> </div> <main id="main-content" class="bd-main" role="main"> <div class="sbt-scroll-pixel-helper"></div> <div class="bd-content"> <div class="bd-article-container"> <div class="bd-header-article d-print-none"> <div class="header-article-items header-article__inner"> <div class="header-article-items__start"> <div class="header-article-item"><button class="sidebar-toggle primary-toggle btn btn-sm" title="Toggle primary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip"> <span class="fa-solid fa-bars"></span> </button></div> </div> <div class="header-article-items__end"> <div class="header-article-item"> <div class="article-header-buttons"> <a href="https://github.com/google/flax" target="_blank" class="btn btn-sm btn-source-repository-button" title="Source repository" data-bs-placement="bottom" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fab fa-github"></i> </span> </a> <div class="dropdown dropdown-download-buttons"> <button class="btn dropdown-toggle" type="button" data-bs-toggle="dropdown" aria-expanded="false" aria-label="Download this page"> <i class="fas fa-download"></i> </button> <ul class="dropdown-menu"> <li><a href="../../_sources/guides/converting_and_upgrading/convert_pytorch_to_flax.rst" target="_blank" class="btn btn-sm btn-download-source-button dropdown-item" title="Download source file" data-bs-placement="left" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-file"></i> </span> <span class="btn__text-container">.rst</span> </a> </li> <li> <button onclick="window.print()" class="btn btn-sm btn-download-pdf-button dropdown-item" title="Print to PDF" data-bs-placement="left" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-file-pdf"></i> </span> <span class="btn__text-container">.pdf</span> </button> </li> </ul> </div> <button onclick="toggleFullScreen()" class="btn btn-sm btn-fullscreen-button" title="Fullscreen mode" data-bs-placement="bottom" data-bs-toggle="tooltip" > <span class="btn__icon-container"> <i class="fas fa-expand"></i> </span> </button> <script> document.write(` <button class="btn btn-sm nav-link pst-navbar-icon theme-switch-button" title="light/dark" aria-label="light/dark" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="theme-switch fa-solid fa-sun fa-lg" data-mode="light"></i> <i class="theme-switch fa-solid fa-moon fa-lg" data-mode="dark"></i> <i class="theme-switch fa-solid fa-circle-half-stroke fa-lg" data-mode="auto"></i> </button> `); </script> <script> document.write(` <button class="btn btn-sm pst-navbar-icon search-button search-button__button" title="Search" aria-label="Search" data-bs-placement="bottom" data-bs-toggle="tooltip"> <i class="fa-solid fa-magnifying-glass fa-lg"></i> </button> `); </script> <button class="sidebar-toggle secondary-toggle btn btn-sm" title="Toggle secondary sidebar" data-bs-placement="bottom" data-bs-toggle="tooltip"> <span class="fa-solid fa-list"></span> </button> </div></div> </div> </div> </div> <div id="jb-print-docs-body" class="onlyprint"> <h1>Convert PyTorch models to Flax</h1> <!-- Table of contents --> <div id="print-main-content"> <div id="jb-print-toc"> <div> <h2> Contents </h2> </div> <nav aria-label="Page"> <ul class="visible nav section-nav flex-column"> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#fc-layers">FC Layers</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#convolutions">Convolutions</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#convolutions-and-fc-layers">Convolutions and FC Layers</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#batch-norm">Batch Norm</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#average-pooling">Average Pooling</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#transposed-convolutions">Transposed Convolutions</a></li> </ul> </nav> </div> </div> </div> <div id="searchbox"></div> <article class="bd-article"> <div class="section" id="convert-pytorch-models-to-flax"> <h1>Convert PyTorch models to Flax<a class="headerlink" href="#convert-pytorch-models-to-flax" title="Permalink to this heading">#</a></h1> <p>We will show how to convert PyTorch models to Flax. We will cover convolutions, fc layers, batch norm, and average pooling.</p> <div class="section" id="fc-layers"> <h2>FC Layers<a class="headerlink" href="#fc-layers" title="Permalink to this heading">#</a></h2> <p>Let’s start with fc layers. The only thing to be aware of here is that the PyTorch kernel has shape [outC, inC] and the Flax kernel has shape [inC, outC]. Transposing the kernel will do the trick.</p> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">t_fc</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span> <span class="n">kernel</span> <span class="o">=</span> <span class="n">t_fc</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="n">bias</span> <span class="o">=</span> <span class="n">t_fc</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="c1"># [outC, inC] -&gt; [inC, outC]</span> <span class="n">kernel</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">kernel</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span> <span class="n">key</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="n">variables</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;params&#39;</span><span class="p">:</span> <span class="p">{</span><span class="s1">&#39;kernel&#39;</span><span class="p">:</span> <span class="n">kernel</span><span class="p">,</span> <span class="s1">&#39;bias&#39;</span><span class="p">:</span> <span class="n">bias</span><span class="p">}}</span> <span class="n">j_fc</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">4</span><span class="p">)</span> <span class="n">j_out</span> <span class="o">=</span> <span class="n">j_fc</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="n">t_x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">x</span><span class="p">))</span> <span class="n">t_out</span> <span class="o">=</span> <span class="n">t_fc</span><span class="p">(</span><span class="n">t_x</span><span class="p">)</span> <span class="n">t_out</span> <span class="o">=</span> <span class="n">t_out</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">j_out</span><span class="p">,</span> <span class="n">t_out</span><span class="p">,</span> <span class="n">decimal</span><span class="o">=</span><span class="mi">6</span><span class="p">)</span> </pre></div> </div> </div> <div class="section" id="convolutions"> <h2>Convolutions<a class="headerlink" href="#convolutions" title="Permalink to this heading">#</a></h2> <p>Let’s now look at 2D convolutions. PyTorch uses the NCHW format and Flax uses NHWC. Consequently, the kernels will have different shapes. The kernel in PyTorch has shape [outC, inC, kH, kW] and the Flax kernel has shape [kH, kW, inC, outC]. Transposing the kernel will do the trick.</p> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">t_conv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s1">&#39;valid&#39;</span><span class="p">)</span> <span class="n">kernel</span> <span class="o">=</span> <span class="n">t_conv</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="n">bias</span> <span class="o">=</span> <span class="n">t_conv</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="c1"># [outC, inC, kH, kW] -&gt; [kH, kW, inC, outC]</span> <span class="n">kernel</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">kernel</span><span class="p">,</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span> <span class="n">key</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="n">variables</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;params&#39;</span><span class="p">:</span> <span class="p">{</span><span class="s1">&#39;kernel&#39;</span><span class="p">:</span> <span class="n">kernel</span><span class="p">,</span> <span class="s1">&#39;bias&#39;</span><span class="p">:</span> <span class="n">bias</span><span class="p">}}</span> <span class="n">j_conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s1">&#39;valid&#39;</span><span class="p">)</span> <span class="n">j_out</span> <span class="o">=</span> <span class="n">j_conv</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="c1"># [N, H, W, C] -&gt; [N, C, H, W]</span> <span class="n">t_x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)))</span> <span class="n">t_out</span> <span class="o">=</span> <span class="n">t_conv</span><span class="p">(</span><span class="n">t_x</span><span class="p">)</span> <span class="c1"># [N, C, H, W] -&gt; [N, H, W, C]</span> <span class="n">t_out</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">t_out</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">j_out</span><span class="p">,</span> <span class="n">t_out</span><span class="p">,</span> <span class="n">decimal</span><span class="o">=</span><span class="mi">6</span><span class="p">)</span> </pre></div> </div> </div> <div class="section" id="convolutions-and-fc-layers"> <h2>Convolutions and FC Layers<a class="headerlink" href="#convolutions-and-fc-layers" title="Permalink to this heading">#</a></h2> <p>We have to be careful, when we have a model that uses convolutions followed by fc layers (ResNet, VGG, etc). In PyTorch, the activations will have shape [N, C, H, W] after the convolutions and are then reshaped to [N, C * H * W] before being fed to the fc layers. When we port our weights from PyTorch to Flax, the activations after the convolutions will be of shape [N, H, W, C] in Flax. Before we reshape the activations for the fc layers, we have to transpose them to [N, C, H, W].</p> <p>Consider this PyTorch model:</p> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">TModel</span><span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="nb">super</span><span class="p">(</span><span class="n">TModel</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Conv2d</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s1">&#39;valid&#39;</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">Linear</span><span class="p">(</span><span class="n">in_features</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> <span class="k">def</span><span class="w"> </span><span class="nf">forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> <span class="n">t_model</span> <span class="o">=</span> <span class="n">TModel</span><span class="p">()</span> </pre></div> </div> <p>Now, if you want to use the weights from this model in Flax, the corresponding Flax model has to look like this:</p> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">JModel</span><span class="p">(</span><span class="n">nn</span><span class="o">.</span><span class="n">Module</span><span class="p">):</span> <span class="nd">@nn</span><span class="o">.</span><span class="n">compact</span> <span class="k">def</span><span class="w"> </span><span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Conv</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s1">&#39;valid&#39;</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">&#39;conv&#39;</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># [N, H, W, C] -&gt; [N, C, H, W]</span> <span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span> <span class="n">x</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span> <span class="n">x</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">&#39;fc&#39;</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> <span class="n">j_model</span> <span class="o">=</span> <span class="n">JModel</span><span class="p">()</span> </pre></div> </div> <p>The model looks very similar to the PyTorch model, except that we included a transpose operation before reshaping our activations for the fc layer. We can omit the transpose operation if we apply pooling before reshaping such that the spatial dimensions are 1x1.</p> <p>Other than the transpose operation before reshaping, we can convert the weights the same way as we did before:</p> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">conv_kernel</span> <span class="o">=</span> <span class="n">t_model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()[</span><span class="s1">&#39;conv.weight&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="n">conv_bias</span> <span class="o">=</span> <span class="n">t_model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()[</span><span class="s1">&#39;conv.bias&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="n">fc_kernel</span> <span class="o">=</span> <span class="n">t_model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()[</span><span class="s1">&#39;fc.weight&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="n">fc_bias</span> <span class="o">=</span> <span class="n">t_model</span><span class="o">.</span><span class="n">state_dict</span><span class="p">()[</span><span class="s1">&#39;fc.bias&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="c1"># [outC, inC, kH, kW] -&gt; [kH, kW, inC, outC]</span> <span class="n">conv_kernel</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">conv_kernel</span><span class="p">,</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span> <span class="c1"># [outC, inC] -&gt; [inC, outC]</span> <span class="n">fc_kernel</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">fc_kernel</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span> <span class="n">variables</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;params&#39;</span><span class="p">:</span> <span class="p">{</span><span class="s1">&#39;conv&#39;</span><span class="p">:</span> <span class="p">{</span><span class="s1">&#39;kernel&#39;</span><span class="p">:</span> <span class="n">conv_kernel</span><span class="p">,</span> <span class="s1">&#39;bias&#39;</span><span class="p">:</span> <span class="n">conv_bias</span><span class="p">},</span> <span class="s1">&#39;fc&#39;</span><span class="p">:</span> <span class="p">{</span><span class="s1">&#39;kernel&#39;</span><span class="p">:</span> <span class="n">fc_kernel</span><span class="p">,</span> <span class="s1">&#39;bias&#39;</span><span class="p">:</span> <span class="n">fc_bias</span><span class="p">}}}</span> <span class="n">key</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="n">j_out</span> <span class="o">=</span> <span class="n">j_model</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="c1"># [N, H, W, C] -&gt; [N, C, H, W]</span> <span class="n">t_x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)))</span> <span class="n">t_out</span> <span class="o">=</span> <span class="n">t_model</span><span class="p">(</span><span class="n">t_x</span><span class="p">)</span> <span class="n">t_out</span> <span class="o">=</span> <span class="n">t_out</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">j_out</span><span class="p">,</span> <span class="n">t_out</span><span class="p">,</span> <span class="n">decimal</span><span class="o">=</span><span class="mi">6</span><span class="p">)</span> </pre></div> </div> </div> <div class="section" id="batch-norm"> <h2>Batch Norm<a class="headerlink" href="#batch-norm" title="Permalink to this heading">#</a></h2> <p><code class="docutils literal notranslate"><span class="pre">torch.nn.BatchNorm2d</span></code> uses <code class="docutils literal notranslate"><span class="pre">0.1</span></code> as the default value for the <code class="docutils literal notranslate"><span class="pre">momentum</span></code> parameter while <a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.BatchNorm"><code class="docutils literal notranslate"><span class="pre">nn.BatchNorm</span></code></a> uses <code class="docutils literal notranslate"><span class="pre">0.9</span></code>. However, this corresponds to the same computation, because PyTorch multiplies the estimated statistic with <code class="docutils literal notranslate"><span class="pre">(1</span> <span class="pre">−</span> <span class="pre">momentum)</span></code> and the new observed value with <code class="docutils literal notranslate"><span class="pre">momentum</span></code>, while Flax multiplies the estimated statistic with <code class="docutils literal notranslate"><span class="pre">momentum</span></code> and the new observed value with <code class="docutils literal notranslate"><span class="pre">(1</span> <span class="pre">−</span> <span class="pre">momentum)</span></code>.</p> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">t_bn</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm2d</span><span class="p">(</span><span class="n">num_features</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.1</span><span class="p">)</span> <span class="n">t_bn</span><span class="o">.</span><span class="n">eval</span><span class="p">()</span> <span class="n">scale</span> <span class="o">=</span> <span class="n">t_bn</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="n">bias</span> <span class="o">=</span> <span class="n">t_bn</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">t_bn</span><span class="o">.</span><span class="n">running_mean</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="n">var</span> <span class="o">=</span> <span class="n">t_bn</span><span class="o">.</span><span class="n">running_var</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="n">variables</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;params&#39;</span><span class="p">:</span> <span class="p">{</span><span class="s1">&#39;scale&#39;</span><span class="p">:</span> <span class="n">scale</span><span class="p">,</span> <span class="s1">&#39;bias&#39;</span><span class="p">:</span> <span class="n">bias</span><span class="p">},</span> <span class="s1">&#39;batch_stats&#39;</span><span class="p">:</span> <span class="p">{</span><span class="s1">&#39;mean&#39;</span><span class="p">:</span> <span class="n">mean</span><span class="p">,</span> <span class="s1">&#39;var&#39;</span><span class="p">:</span> <span class="n">var</span><span class="p">}}</span> <span class="n">key</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="n">j_bn</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">BatchNorm</span><span class="p">(</span><span class="n">momentum</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="n">use_running_average</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">j_out</span> <span class="o">=</span> <span class="n">j_bn</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="c1"># [N, H, W, C] -&gt; [N, C, H, W]</span> <span class="n">t_x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)))</span> <span class="n">t_out</span> <span class="o">=</span> <span class="n">t_bn</span><span class="p">(</span><span class="n">t_x</span><span class="p">)</span> <span class="c1"># [N, C, H, W] -&gt; [N, H, W, C]</span> <span class="n">t_out</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">t_out</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">j_out</span><span class="p">,</span> <span class="n">t_out</span><span class="p">,</span> <span class="n">decimal</span><span class="o">=</span><span class="mi">6</span><span class="p">)</span> </pre></div> </div> </div> <div class="section" id="average-pooling"> <h2>Average Pooling<a class="headerlink" href="#average-pooling" title="Permalink to this heading">#</a></h2> <p><code class="docutils literal notranslate"><span class="pre">torch.nn.AvgPool2d</span></code> and <a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.avg_pool"><code class="docutils literal notranslate"><span class="pre">nn.avg_pool()</span></code></a> are compatible when using default parameters. However, <code class="docutils literal notranslate"><span class="pre">torch.nn.AvgPool2d</span></code> has a parameter <code class="docutils literal notranslate"><span class="pre">count_include_pad</span></code>. When <code class="docutils literal notranslate"><span class="pre">count_include_pad=False</span></code>, the zero-padding will not be considered for the average calculation. There does not exist a similar parameter for <a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.avg_pool"><code class="docutils literal notranslate"><span class="pre">nn.avg_pool()</span></code></a>. However, we can easily implement a wrapper around the pooling operation. <code class="docutils literal notranslate"><span class="pre">nn.pool()</span></code> is the core function behind <a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.avg_pool"><code class="docutils literal notranslate"><span class="pre">nn.avg_pool()</span></code></a> and <a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.max_pool"><code class="docutils literal notranslate"><span class="pre">nn.max_pool()</span></code></a>.</p> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">def</span><span class="w"> </span><span class="nf">avg_pool</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">window_shape</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s1">&#39;VALID&#39;</span><span class="p">):</span> <span class="w"> </span><span class="sd">&quot;&quot;&quot;</span> <span class="sd"> Pools the input by taking the average over a window.</span> <span class="sd"> In comparison to nn.avg_pool(), this pooling operation does not</span> <span class="sd"> consider the padded zero&#39;s for the average computation.</span> <span class="sd"> &quot;&quot;&quot;</span> <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">window_shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span> <span class="n">y</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">pool</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="mf">0.</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">lax</span><span class="o">.</span><span class="n">add</span><span class="p">,</span> <span class="n">window_shape</span><span class="p">,</span> <span class="n">strides</span><span class="p">,</span> <span class="n">padding</span><span class="p">)</span> <span class="n">counts</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">pool</span><span class="p">(</span><span class="n">jnp</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">inputs</span><span class="p">),</span> <span class="mf">0.</span><span class="p">,</span> <span class="n">jax</span><span class="o">.</span><span class="n">lax</span><span class="o">.</span><span class="n">add</span><span class="p">,</span> <span class="n">window_shape</span><span class="p">,</span> <span class="n">strides</span><span class="p">,</span> <span class="n">padding</span><span class="p">)</span> <span class="n">y</span> <span class="o">=</span> <span class="n">y</span> <span class="o">/</span> <span class="n">counts</span> <span class="k">return</span> <span class="n">y</span> <span class="n">key</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="n">j_out</span> <span class="o">=</span> <span class="n">avg_pool</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">window_shape</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">strides</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)))</span> <span class="n">t_pool</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">AvgPool2d</span><span class="p">(</span><span class="n">kernel_size</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">count_include_pad</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="c1"># [N, H, W, C] -&gt; [N, C, H, W]</span> <span class="n">t_x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)))</span> <span class="n">t_out</span> <span class="o">=</span> <span class="n">t_pool</span><span class="p">(</span><span class="n">t_x</span><span class="p">)</span> <span class="c1"># [N, C, H, W] -&gt; [N, H, W, C]</span> <span class="n">t_out</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">t_out</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">j_out</span><span class="p">,</span> <span class="n">t_out</span><span class="p">,</span> <span class="n">decimal</span><span class="o">=</span><span class="mi">6</span><span class="p">)</span> </pre></div> </div> </div> <div class="section" id="transposed-convolutions"> <h2>Transposed Convolutions<a class="headerlink" href="#transposed-convolutions" title="Permalink to this heading">#</a></h2> <p><code class="docutils literal notranslate"><span class="pre">torch.nn.ConvTranspose2d</span></code> and <a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.ConvTranspose"><code class="docutils literal notranslate"><span class="pre">nn.ConvTranspose</span></code></a> are not compatible. <a class="reference external" href="https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.ConvTranspose"><code class="docutils literal notranslate"><span class="pre">nn.ConvTranspose</span></code></a> is a wrapper around <a class="reference external" href="https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_transpose.html"><code class="docutils literal notranslate"><span class="pre">jax.lax.conv_transpose</span></code></a> which computes a fractionally strided convolution, while <code class="docutils literal notranslate"><span class="pre">torch.nn.ConvTranspose2d</span></code> computes a gradient based transposed convolution. Currently, there is no implementation of a gradient based transposed convolution is <code class="docutils literal notranslate"><span class="pre">Jax</span></code>. However, there is a pending <a class="reference external" href="https://github.com/jax-ml/jax/pull/5772">pull request</a> that contains an implementation.</p> <p>To load <code class="docutils literal notranslate"><span class="pre">torch.nn.ConvTranspose2d</span></code> parameters into Flax, we need to use the <code class="docutils literal notranslate"><span class="pre">transpose_kernel</span></code> arg in Flax’s <code class="docutils literal notranslate"><span class="pre">nn.ConvTranspose</span></code> layer.</p> <div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># padding is inverted</span> <span class="n">torch_padding</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">flax_padding</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">torch_padding</span> <span class="n">t_conv</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">ConvTranspose2d</span><span class="p">(</span><span class="n">in_channels</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">out_channels</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="n">torch_padding</span><span class="p">)</span> <span class="n">kernel</span> <span class="o">=</span> <span class="n">t_conv</span><span class="o">.</span><span class="n">weight</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="n">bias</span> <span class="o">=</span> <span class="n">t_conv</span><span class="o">.</span><span class="n">bias</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="c1"># [inC, outC, kH, kW] -&gt; [kH, kW, outC, inC]</span> <span class="n">kernel</span> <span class="o">=</span> <span class="n">jnp</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">kernel</span><span class="p">,</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">))</span> <span class="n">key</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">key</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="n">variables</span> <span class="o">=</span> <span class="p">{</span><span class="s1">&#39;params&#39;</span><span class="p">:</span> <span class="p">{</span><span class="s1">&#39;kernel&#39;</span><span class="p">:</span> <span class="n">kernel</span><span class="p">,</span> <span class="s1">&#39;bias&#39;</span><span class="p">:</span> <span class="n">bias</span><span class="p">}}</span> <span class="c1"># ConvTranspose expects the kernel to be [kH, kW, inC, outC],</span> <span class="c1"># but with `transpose_kernel=True`, it expects [kH, kW, outC, inC] instead</span> <span class="n">j_conv</span> <span class="o">=</span> <span class="n">nn</span><span class="o">.</span><span class="n">ConvTranspose</span><span class="p">(</span><span class="n">features</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="n">flax_padding</span><span class="p">,</span> <span class="n">transpose_kernel</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">j_out</span> <span class="o">=</span> <span class="n">j_conv</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">variables</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="c1"># [N, H, W, C] -&gt; [N, C, H, W]</span> <span class="n">t_x</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">from_numpy</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">)))</span> <span class="n">t_out</span> <span class="o">=</span> <span class="n">t_conv</span><span class="p">(</span><span class="n">t_x</span><span class="p">)</span> <span class="c1"># [N, C, H, W] -&gt; [N, H, W, C]</span> <span class="n">t_out</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">t_out</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="n">np</span><span class="o">.</span><span class="n">testing</span><span class="o">.</span><span class="n">assert_almost_equal</span><span class="p">(</span><span class="n">j_out</span><span class="p">,</span> <span class="n">t_out</span><span class="p">,</span> <span class="n">decimal</span><span class="o">=</span><span class="mi">6</span><span class="p">)</span> </pre></div> </div> </div> </div> </article> <footer class="prev-next-footer d-print-none"> <div class="prev-next-area"> <a class="left-prev" href="haiku_migration_guide.html" title="previous page"> <i class="fa-solid fa-angle-left"></i> <div class="prev-next-info"> <p class="prev-next-subtitle">previous</p> <p class="prev-next-title">Migrating from Haiku to Flax</p> </div> </a> <a class="right-next" href="orbax_upgrade_guide.html" title="next page"> <div class="prev-next-info"> <p class="prev-next-subtitle">next</p> <p class="prev-next-title">Migrate checkpointing to Orbax</p> </div> <i class="fa-solid fa-angle-right"></i> </a> </div> </footer> </div> <div class="bd-sidebar-secondary bd-toc"><div class="sidebar-secondary-items sidebar-secondary__inner"> <div class="sidebar-secondary-item"> <div class="page-toc tocsection onthispage"> <i class="fa-solid fa-list"></i> Contents </div> <nav class="bd-toc-nav page-toc"> <ul class="visible nav section-nav flex-column"> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#fc-layers">FC Layers</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#convolutions">Convolutions</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#convolutions-and-fc-layers">Convolutions and FC Layers</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#batch-norm">Batch Norm</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#average-pooling">Average Pooling</a></li> <li class="toc-h2 nav-item toc-entry"><a class="reference internal nav-link" href="#transposed-convolutions">Transposed Convolutions</a></li> </ul> </nav></div> </div></div> </div> <footer class="bd-footer-content"> <div class="bd-footer-content__inner container"> <div class="footer-item"> <p class="component-author"> By The Flax authors </p> </div> <div class="footer-item"> <p class="copyright"> © Copyright 2023, The Flax authors. <br/> </p> </div> <div class="footer-item"> </div> <div class="footer-item"> </div> </div> </footer> </main> </div> </div> <!-- Scripts loaded after <body> so the DOM is not blocked --> <script src="../../_static/scripts/bootstrap.js?digest=dfe6caa3a7d634c4db9b"></script> <script src="../../_static/scripts/pydata-sphinx-theme.js?digest=dfe6caa3a7d634c4db9b"></script> <footer class="bd-footer"> </footer> </body> </html>

Pages: 1 2 3 4 5 6 7 8 9 10