<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/api/mixed_precision/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Mixed precision"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Mixed precision"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Mixed precision</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer 
src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-sublink" href="/api/models/">Models API</a> <a class="nav-sublink" href="/api/layers/">Layers API</a> <a class="nav-sublink" href="/api/callbacks/">Callbacks API</a> <a class="nav-sublink" href="/api/ops/">Ops API</a> <a class="nav-sublink" href="/api/optimizers/">Optimizers</a> <a class="nav-sublink" href="/api/metrics/">Metrics</a> <a class="nav-sublink" href="/api/losses/">Losses</a> <a class="nav-sublink" href="/api/data_loading/">Data loading</a> <a class="nav-sublink" href="/api/datasets/">Built-in small datasets</a> <a class="nav-sublink" href="/api/applications/">Keras Applications</a> <a class="nav-sublink active" href="/api/mixed_precision/">Mixed precision</a> <a class="nav-sublink2" href="/api/mixed_precision/policy/">Mixed precision policy API</a> <a class="nav-sublink" href="/api/distribution/">Multi-device distribution</a> <a class="nav-sublink" href="/api/random/">RNG API</a> <a class="nav-sublink" href="/api/utils/">Utilities</a> <a class="nav-sublink" href="/api/keras_tuner/">KerasTuner</a> <a class="nav-sublink" href="/api/keras_cv/">KerasCV</a> <a class="nav-sublink" 
href="/api/keras_nlp/">KerasNLP</a> <a class="nav-sublink" href="/api/keras_hub/">KerasHub</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" 
class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + encodeURIComponent(query); return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/api/'>Keras 3 API documentation</a> / Mixed precision </div> <div class='k-content'> <h1 id="mixed-precision">Mixed precision</h1> <h2 id="what-is-mixed-precision-training">What is mixed precision training?</h2> <p>Mixed precision training is the use of lower-precision operations (<code>float16</code> and <code>bfloat16</code>) in a model during training to make it run faster and use less memory. 
Using mixed precision can improve performance by more than 3x on modern GPUs and by 60% on TPUs.</p> <p>Today, most models use the <code>float32</code> dtype, which takes 32 bits of memory. However, there are two lower-precision dtypes, <code>float16</code> and <code>bfloat16</code>, each of which takes 16 bits of memory instead. Modern accelerators like Google TPUs and NVIDIA GPUs can run operations faster in the 16-bit dtypes, because they have specialized hardware for 16-bit computations and because 16-bit values can be read from memory faster. Therefore, these lower-precision dtypes should be used whenever possible on those devices.</p> <p>However, variable storage (as well as certain numerically sensitive computations) should still use <code>float32</code> to preserve numerical stability. By using 16-bit precision whenever possible and keeping certain critical parts of the model in <code>float32</code>, the model runs faster while still training as well as it would with 32-bit precision.</p> <h2 id="using-mixed-precision-training-in-keras">Using mixed precision training in Keras</h2> <p>The precision policy used by Keras layers or models is controlled by a <code>keras.mixed_precision.DTypePolicy</code> instance. Each layer has its own <code>DTypePolicy</code>. You can either set it on an individual layer via the <code>dtype</code> argument (e.g. <code>MyLayer(..., dtype="mixed_float16")</code>), or set a global default used by all layers via the utility <code>keras.mixed_precision.set_global_policy</code>.</p> <p>Typically, to start using mixed precision on GPU, you would simply call <code>keras.mixed_precision.set_global_policy("mixed_float16")</code> at the start of your program. 
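</p> <p>Under the <code>"mixed_float16"</code> policy, layers run their computations in <code>float16</code> but keep their variables in <code>float32</code>. The following sketch illustrates why that split matters. It does not use Keras at all: it relies only on Python's standard <code>struct</code> module, whose <code>"e"</code> and <code>"f"</code> format codes round a value to IEEE 754 half and single precision, and it repeatedly applies a small update of <code>1e-4</code> to a weight of <code>1.0</code>, storing the weight in one of the two formats after every step. The helpers <code>round_to</code> and <code>accumulate</code> are hypothetical names used only for this illustration, and the loop is a simplified stand-in for an optimizer step, not how Keras actually updates variables.</p> <div class="codehilite"><pre><span></span><code>import struct


def round_to(fmt, x):
    # Hypothetical helper: round a Python float (float64) to the IEEE 754
    # format named by a struct code ("e" = float16, "f" = float32).
    return struct.unpack(fmt, struct.pack(fmt, x))[0]


def accumulate(weight, update, steps, fmt):
    # Hypothetical helper: apply the same small update `steps` times and
    # store the weight in the given format after every step, as a variable
    # kept in that dtype would be.
    for _ in range(steps):
        weight = round_to(fmt, weight + update)
    return weight


w16 = accumulate(1.0, 1e-4, 1000, "e")  # weight stored in float16
w32 = accumulate(1.0, 1e-4, 1000, "f")  # weight stored in float32

print(w16)  # 1.0: every update is rounded away
print(w32)  # close to 1.1
</code></pre></div> <p>Near <code>1.0</code>, adjacent <code>float16</code> values are 2<sup>-10</sup> (about 0.00098) apart, so each 0.0001 update rounds back to the same value and the <code>float16</code> weight never changes, while the <code>float32</code> weight ends up close to the expected 1.1.</p> <p>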
On TPU, you would call <code>keras.mixed_precision.set_global_policy("mixed_bfloat16")</code> instead.</p> <h2 id="api-documentation">API documentation</h2> <h3 id="mixed-precision-policy-api"><a href="/api/mixed_precision/policy/">Mixed precision policy API</a></h3> <ul> <li><a href="/api/mixed_precision/policy/#dtypepolicy-class">DTypePolicy class</a></li> <li><a href="/api/mixed_precision/policy/#dtypepolicymap-class">DTypePolicyMap class</a></li> <li><a href="/api/mixed_precision/policy/#floatdtypepolicy-class">FloatDTypePolicy class</a></li> <li><a href="/api/mixed_precision/policy/#quantizeddtypepolicy-class">QuantizedDTypePolicy class</a></li> <li><a href="/api/mixed_precision/policy/#quantizedfloat8dtypepolicy-class">QuantizedFloat8DTypePolicy class</a></li> <li><a href="/api/mixed_precision/policy/#dtype_policy-function">dtype_policy function</a></li> <li><a href="/api/mixed_precision/policy/#set_dtype_policy-function">set_dtype_policy function</a></li> </ul> <h2 id="supported-hardware">Supported hardware</h2> <p>While mixed precision will run on most hardware, it will only speed up models on recent NVIDIA GPUs and Google TPUs. NVIDIA GPUs support using a mix of <code>float16</code> and <code>float32</code>, while TPUs support a mix of <code>bfloat16</code> and <code>float32</code>.</p> <p>Among NVIDIA GPUs, those with compute capability 7.0 or higher will see the greatest performance benefit from mixed precision because they have special hardware units, called Tensor Cores, to accelerate <code>float16</code> matrix multiplications and convolutions. Older GPUs offer no math performance benefit from mixed precision; however, the memory and bandwidth savings can still enable some speedups. You can look up the compute capability for your GPU at NVIDIA's <a href="https://developer.nvidia.com/cuda-gpus">CUDA GPU web page</a>. 
Examples of GPUs that will benefit most from mixed precision include RTX GPUs, the V100, and the A100.</p> <p>Even on CPUs and older GPUs, where no speedup is expected, mixed precision APIs can still be used for unit testing, debugging, or just to try out the API. On CPUs, however, mixed precision will run significantly slower.</p> <p>You can check your GPU type with the following command:</p> <div class="codehilite"><pre><span></span><code>nvidia-smi -L </code></pre></div> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#mixed-precision'>Mixed precision</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#what-is-mixed-precision-training'>What is mixed precision training?</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#using-mixed-precision-training-in-keras'>Using mixed precision training in Keras</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#api-documentation'>API documentation</a> </div> <div class='k-outline-depth-3'> <a href='#mixed-precision-policy-api'>Mixed precision policy API</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#supported-hardware'>Supported hardware</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>