<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/api/mixed_precision/policy/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Mixed precision policy API"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Mixed precision policy API"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Mixed precision policy API</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> 
<script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-link active" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-sublink" href="/api/models/">Models API</a> <a class="nav-sublink" href="/api/layers/">Layers API</a> <a class="nav-sublink" href="/api/callbacks/">Callbacks API</a> <a class="nav-sublink" href="/api/ops/">Ops API</a> <a class="nav-sublink" href="/api/optimizers/">Optimizers</a> <a class="nav-sublink" href="/api/metrics/">Metrics</a> <a class="nav-sublink" href="/api/losses/">Losses</a> <a class="nav-sublink" href="/api/data_loading/">Data loading</a> <a class="nav-sublink" href="/api/datasets/">Built-in small datasets</a> <a class="nav-sublink" href="/api/applications/">Keras Applications</a> <a class="nav-sublink active" href="/api/mixed_precision/">Mixed precision</a> <a class="nav-sublink2 active" href="/api/mixed_precision/policy/">Mixed precision policy API</a> <a class="nav-sublink" href="/api/distribution/">Multi-device distribution</a> <a class="nav-sublink" href="/api/random/">RNG API</a> <a class="nav-sublink" href="/api/utils/">Utilities</a> <a class="nav-link" href="/2.18/api/" 
role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { var e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." 
autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + encodeURIComponent(query); return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-content'> <h1 id="mixed-precision-policy-api">Mixed precision policy API</h1> <p><span style="float:right;"><a href="https://github.com/keras-team/keras/tree/v3.8.0/keras/src/dtype_policies/dtype_policy.py#L9">[source]</a></span></p> <h3 id="dtypepolicy-class"><code>DTypePolicy</code> class</h3> <div class="codehilite"><pre><span></span><code><span class="n">keras</span><span class="o">.</span><span class="n">dtype_policies</span><span class="o">.</span><span class="n">DTypePolicy</span><span 
class="p">(</span><span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> </code></pre></div> <p>A dtype policy for a Keras layer.</p> <p>A dtype policy determines a layer's compute and variable dtypes. Each layer has a policy. Policies can be passed to the <code>dtype</code> argument of layer constructors, or a global policy can be set with <a href="/api/mixed_precision/policy#setdtypepolicy-function"><code>keras.config.set_dtype_policy</code></a>.</p> <p><strong>Arguments</strong></p> <ul> <li><strong>name</strong>: The policy name, which determines the compute and variable dtypes. Can be any dtype name, such as <code>"float32"</code> or <code>"float64"</code>, in which case both the compute and variable dtypes are that dtype. Can also be <code>"mixed_float16"</code> or <code>"mixed_bfloat16"</code>, in which case the compute dtype is <code>float16</code> or <code>bfloat16</code>, respectively, and the variable dtype is <code>float32</code>.</li> </ul> <p>Typically you only need to interact with dtype policies when using mixed precision, which is the use of float16 or bfloat16 for computations and float32 for variables. This is why the term <code>mixed_precision</code> appears in the API name. 
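</p> <p>The naming rules for <code>name</code> can be summarized as a small function. The following is a minimal stdlib-only sketch, not Keras code: <code>resolve_policy_name</code> and <code>MIXED_PREFIX</code> are hypothetical names chosen for illustration, and apart from rejecting unsupported <code>mixed_</code> names the sketch does not check that a name is a valid dtype. For a given policy name it returns the <code>(compute_dtype, variable_dtype)</code> pair described in the Arguments list above.</p> <div class="codehilite"><pre><span></span><code># Hypothetical sketch (not part of Keras): maps a dtype policy name to the
# (compute_dtype, variable_dtype) pair described in the DTypePolicy docs.
MIXED_PREFIX = "mixed_"


def resolve_policy_name(name):
    """Return (compute_dtype, variable_dtype) for a dtype policy name."""
    if name in ("mixed_float16", "mixed_bfloat16"):
        # Mixed precision: compute in the 16-bit dtype, keep variables in float32.
        return name[len(MIXED_PREFIX):], "float32"
    if name.startswith(MIXED_PREFIX):
        raise ValueError(f"Unsupported mixed precision policy name: {name!r}")
    # Any other dtype name: compute and variable dtypes are both that dtype.
    return name, name


for policy_name in ("float32", "mixed_float16", "mixed_bfloat16"):
    compute_dtype, variable_dtype = resolve_policy_name(policy_name)
    print(f"{policy_name}: compute={compute_dtype}, variable={variable_dtype}")
# Prints:
# float32: compute=float32, variable=float32
# mixed_float16: compute=float16, variable=float32
# mixed_bfloat16: compute=bfloat16, variable=float32
</code></pre></div> <p>In Keras itself, the same pair is available from a policy object through its <code>compute_dtype</code> and <code>variable_dtype</code> properties; for <code>keras.dtype_policies.DTypePolicy("mixed_bfloat16")</code> they are <code>"bfloat16"</code> and <code>"float32"</code>.</p> <p>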
Mixed precision can be enabled by passing <code>"mixed_float16"</code> or <code>"mixed_bfloat16"</code> to <code>keras.config.set_dtype_policy()</code>.</p> <div class="codehilite"><pre><span></span><code><span class="o">&gt;&gt;&gt;</span> <span class="n">keras</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">set_dtype_policy</span><span class="p">(</span><span class="s2">"mixed_float16"</span><span class="p">)</span>
<span class="o">&gt;&gt;&gt;</span> <span class="n">layer1</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">)</span>
<span class="o">&gt;&gt;&gt;</span> <span class="n">layer1</span><span class="o">.</span><span class="n">dtype_policy</span>  <span class="c1"># layer1 will automatically use mixed precision</span>
<span class="o">&lt;</span><span class="n">DTypePolicy</span> <span class="s2">"mixed_float16"</span><span class="o">&gt;</span>
<span class="o">&gt;&gt;&gt;</span> <span class="c1"># Can optionally override layer to use float32</span>
<span class="o">&gt;&gt;&gt;</span> <span class="c1"># instead of mixed precision.</span>
<span class="o">&gt;&gt;&gt;</span> <span class="n">layer2</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span><span class="p">)</span>
<span class="o">&gt;&gt;&gt;</span> <span class="n">layer2</span><span class="o">.</span><span class="n">dtype_policy</span>
<span class="o">&lt;</span><span class="n">DTypePolicy</span> <span class="s2">"float32"</span><span class="o">&gt;</span>
<span class="o">&gt;&gt;&gt;</span> <span class="c1"># Set policy back to initial float32.</span>
<span class="o">&gt;&gt;&gt;</span> <span class="n">keras</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">set_dtype_policy</span><span class="p">(</span><span class="s1">'float32'</span><span class="p">)</span>
</code></pre></div> <p>In the example above, passing <code>dtype="float32"</code> to the layer is equivalent to passing <code>dtype=keras.dtype_policies.DTypePolicy("float32")</code>. In general, passing a dtype policy name to a layer is equivalent to passing the corresponding policy, so it is never necessary to construct a <code>DTypePolicy</code> object explicitly just to set a layer's <code>dtype</code>.</p> <hr /> <p><span style="float:right;"><a href="https://github.com/keras-team/keras/tree/v3.8.0/keras/src/dtype_policies/dtype_policy_map.py#L9">[source]</a></span></p> <h3 id="dtypepolicymap-class"><code>DTypePolicyMap</code> class</h3> <div class="codehilite"><pre><span></span><code><span class="n">keras</span><span class="o">.</span><span class="n">dtype_policies</span><span class="o">.</span><span class="n">DTypePolicyMap</span><span class="p">(</span><span class="n">default_policy</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">policy_map</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> </code></pre></div> <p>Dict-like object mapping layer paths to <code>DTypePolicy</code> instances.</p> <p><code>DTypePolicyMap</code> can be used in the <code>get_config</code> method of layers and their subclasses to support complex configurations of dtype policies, such as different policies for different sublayers.</p> <p>For example, a subclass of <code>keras.layers.MultiHeadAttention</code> can override <code>get_config</code> as follows so that its config records the dtype policy of every quantized layer, keyed by layer path:</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">keras</span>


<span class="nd">@keras</span><span class="o">.</span><span class="n">saving</span><span class="o">.</span><span class="n">register_keras_serializable</span><span class="p">(</span><span class="s2">"MyPackage"</span><span class="p">)</span>
<span class="k">class</span><span class="w"> </span><span class="nc">MyMultiHeadAttention</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">MultiHeadAttention</span><span class="p">):</span>
    <span class="k">def</span><span class="w"> </span><span class="nf">get_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="n">config</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">get_config</span><span class="p">()</span>
        <span class="c1"># Record the dtype policy of each quantized layer, keyed by its path.</span>
        <span class="n">dtype_policy_map</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">dtype_policies</span><span class="o">.</span><span class="n">DTypePolicyMap</span><span class="p">()</span>
        <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">_flatten_layers</span><span class="p">():</span>
            <span class="k">if</span> <span class="n">layer</span><span class="o">.</span><span class="n">dtype_policy</span><span class="o">.</span><span class="n">quantization_mode</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
                <span class="n">dtype_policy_map</span><span class="p">[</span><span class="n">layer</span><span class="o">.</span><span class="n">path</span><span class="p">]</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">dtype_policy</span>
        <span class="c1"># Only store the map in the config when at least one layer is quantized.</span>
        <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">dtype_policy_map</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">:</span>
            <span class="n">config</span><span class="o">.</span><span class="n">update</span><span class="p">({</span><span class="s2">"dtype"</span><span class="p">:</span> <span class="n">dtype_policy_map</span><span class="p">})</span>
        <span class="k">return</span> <span class="n">config</span>
</code></pre></div> <p>Internally, <code>DTypePolicyMap</code> maps string keys to <code>DTypePolicy</code> values. The key used for lookups is typically the <code>Layer.path</code>, but a key can also be a regular expression; see the docstring of <code>get</code> for details.</p> <p>In the following example, policies are assigned to two layer paths and then retrieved by path. A path with no matching key falls back to the default policy.</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">keras</span>
<span class="kn">from</span> <span class="nn">keras.dtype_policies</span> <span class="kn">import</span> <span class="n">DTypePolicy</span><span class="p">,</span> <span class="n">DTypePolicyMap</span><span class="p">,</span> <span class="n">QuantizedDTypePolicy</span>

<span class="n">dtype_policy_map</span> <span class="o">=</span> <span class="n">DTypePolicyMap</span><span class="p">()</span>
<span class="n">dtype_policy_map</span><span class="p">[</span><span class="s2">"layer/dense_0"</span><span class="p">]</span> <span class="o">=</span> <span class="n">DTypePolicy</span><span class="p">(</span><span class="s2">"bfloat16"</span><span class="p">)</span>
<span class="n">dtype_policy_map</span><span class="p">[</span><span class="s2">"layer/dense_1"</span><span class="p">]</span> <span class="o">=</span> <span class="n">QuantizedDTypePolicy</span><span class="p">(</span><span class="s2">"int8"</span><span class="p">,</span> <span class="s2">"bfloat16"</span><span class="p">)</span>

<span class="n">policy_0</span> <span class="o">=</span> <span class="n">dtype_policy_map</span><span class="p">[</span><span class="s2">"layer/dense_0"</span><span class="p">]</span>
<span class="n">policy_1</span> <span class="o">=</span> <span class="n">dtype_policy_map</span><span class="p">[</span><span class="s2">"layer/dense_1"</span><span class="p">]</span>
<span class="n">policy_2</span> <span class="o">=</span> <span class="n">dtype_policy_map</span><span class="p">[</span><span class="s2">"layer/dense_2"</span><span class="p">]</span>  <span class="c1"># No matching key: falls back to the default policy</span>

<span class="k">assert</span> <span class="n">policy_0</span> <span class="o">==</span> <span class="n">DTypePolicy</span><span class="p">(</span><span class="s2">"bfloat16"</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">policy_1</span> <span class="o">==</span> <span class="n">QuantizedDTypePolicy</span><span class="p">(</span><span class="s2">"int8"</span><span class="p">,</span> <span class="s2">"bfloat16"</span><span class="p">)</span>
<span class="k">assert</span> <span class="n">policy_2</span> <span class="o">==</span> <span class="n">keras</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dtype_policy</span><span class="p">()</span>
</code></pre></div> <p><strong>Arguments</strong></p> <ul> <li><strong>default_policy</strong>: An optional <code>DTypePolicy</code> instance specifying the default dtype policy. If not specified, defaults to <code>keras.config.dtype_policy()</code>.</li> <li><strong>policy_map</strong>: An optional dict mapping strings to <code>DTypePolicy</code> instances. Defaults to <code>None</code>.</li> </ul> <hr /> <p><span style="float:right;"><a href="https://github.com/keras-team/keras/tree/v3.8.0/keras/src/dtype_policies/dtype_policy.py#L207">[source]</a></span></p> <h3 id="floatdtypepolicy-class"><code>FloatDTypePolicy</code> class</h3> <div class="codehilite"><pre><span></span><code><span class="n">keras</span><span class="o">.</span><span class="n">dtype_policies</span><span class="o">.</span><span class="n">FloatDTypePolicy</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> </code></pre></div> <p>A floating-point dtype policy for a Keras layer.</p> <p><code>FloatDTypePolicy</code> is a subclass of <code>DTypePolicy</code> that takes the same <code>name</code> argument and follows the same rules for compute and variable dtypes. See the <a href="#dtypepolicy-class"><code>DTypePolicy</code></a> documentation above for the accepted names and a usage example.</p> <hr /> <p><span style="float:right;"><a href="https://github.com/keras-team/keras/tree/v3.8.0/keras/src/dtype_policies/dtype_policy.py#L215">[source]</a></span></p> <h3 id="quantizeddtypepolicy-class"><code>QuantizedDTypePolicy</code> class</h3> <div class="codehilite"><pre><span></span><code><span class="n">keras</span><span class="o">.</span><span class="n">dtype_policies</span><span class="o">.</span><span class="n">QuantizedDTypePolicy</span><span class="p">(</span><span class="n">mode</span><span class="p">,</span> <span class="n">source_name</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> </code></pre></div> <p>A dtype policy for a quantized Keras layer.</p> <p><code>QuantizedDTypePolicy</code> is a subclass of <code>DTypePolicy</code> that pairs a quantization mode with a floating-point source policy. Unlike <code>DTypePolicy</code>, it does not take a <code>name</code> argument.</p> <p><strong>Arguments</strong></p> <ul> <li><strong>mode</strong>: The quantization mode, for example <code>"int8"</code>. The mode is exposed as the policy's <code>quantization_mode</code>.</li> <li><strong>source_name</strong>: The name of the floating-point policy, such as <code>"bfloat16"</code>, that supplies the compute and variable dtypes. If <code>None</code>, the name of the global dtype policy is used.</li> </ul> <hr /> <p><span style="float:right;"><a href="https://github.com/keras-team/keras/tree/v3.8.0/keras/src/dtype_policies/dtype_policy.py#L259">[source]</a></span></p> <h3 id="quantizedfloat8dtypepolicy-class"><code>QuantizedFloat8DTypePolicy</code> class</h3> <div class="codehilite"><pre><span></span><code><span class="n">keras</span><span class="o">.</span><span class="n">dtype_policies</span><span class="o">.</span><span class="n">QuantizedFloat8DTypePolicy</span><span class="p">(</span>
    <span class="n">mode</span><span class="p">,</span> <span class="n">source_name</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">amax_history_length</span><span class="o">=</span><span class="mi">1024</span>
<span class="p">)</span>
</code></pre></div> <p>A dtype policy for a float8-quantized Keras layer.</p> <p><code>QuantizedFloat8DTypePolicy</code> accepts the same <code>mode</code> and <code>source_name</code> arguments as <code>QuantizedDTypePolicy</code>, plus one more:</p> <ul> <li><strong>amax_history_length</strong>: The length of the history of absolute-maximum ("amax") values used to compute scaling factors in float8 training. Defaults to <code>1024</code>.</li> </ul> <hr /> <p><span style="float:right;"><a href="https://github.com/keras-team/keras/tree/v3.8.0/keras/src/dtype_policies/dtype_policy.py#L322">[source]</a></span></p> <h3 id="dtypepolicy-function"><code>dtype_policy</code> function</h3> <div class="codehilite"><pre><span></span><code><span class="n">keras</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">dtype_policy</span><span class="p">()</span> </code></pre></div> <p>Returns the current global (default) dtype policy object.</p> <hr /> <p><span style="float:right;"><a href="https://github.com/keras-team/keras/tree/v3.8.0/keras/src/dtype_policies/dtype_policy.py#L291">[source]</a></span></p> <h3 id="setdtypepolicy-function"><code>set_dtype_policy</code> function</h3> <div class="codehilite"><pre><span></span><code><span class="n">keras</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">set_dtype_policy</span><span class="p">(</span><span class="n">policy</span><span class="p">)</span> </code></pre></div> <p>Sets the default dtype policy globally. <code>policy</code> can be a <code>DTypePolicy</code> instance or a policy name string, such as <code>"mixed_float16"</code>.</p> <p><strong>Example</strong></p> <div class="codehilite"><pre><span></span><code><span class="o">&gt;&gt;&gt;</span> <span class="n">keras</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">set_dtype_policy</span><span class="p">(</span><span class="s2">"mixed_float16"</span><span class="p">)</span>
</code></pre></div> <hr /> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#mixed-precision-policy-api'>Mixed precision policy API</a> </div> <div class='k-outline-depth-3'> <a href='#dtypepolicy-class'><code>DTypePolicy</code> class</a> </div> <div class='k-outline-depth-3'> <a href='#dtypepolicymap-class'><code>DTypePolicyMap</code> 
class</a> </div> <div class='k-outline-depth-3'> <a href='#floatdtypepolicy-class'><code>FloatDTypePolicy</code> class</a> </div> <div class='k-outline-depth-3'> <a href='#quantizeddtypepolicy-class'><code>QuantizedDTypePolicy</code> class</a> </div> <div class='k-outline-depth-3'> <a href='#quantizedfloat8dtypepolicy-class'><code>QuantizedFloat8DTypePolicy</code> class</a> </div> <div class='k-outline-depth-3'> <a href='#dtypepolicy-function'><code>dtype_policy</code> function</a> </div> <div class='k-outline-depth-3'> <a href='#setdtypepolicy-function'><code>set_dtype_policy</code> function</a> </div> </div> </div> </div> </div> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </body> </html>