Keras 3 benchmarks

<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/getting_started/benchmarks/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Keras 3 benchmarks"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Keras 3 benchmarks"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Keras 3 benchmarks</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer 
src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link active" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-sublink" href="/getting_started/intro_to_keras_for_engineers/">Introduction to Keras for engineers</a> <a class="nav-sublink active" href="/getting_started/benchmarks/">Keras 3 benchmarks</a> <a class="nav-sublink" href="/getting_started/ecosystem/">The Keras ecosystem</a> <a class="nav-sublink" href="/getting_started/faq/">Frequently Asked Questions</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { 
document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." 
autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/getting_started/'>Getting started</a> / Keras 3 benchmarks </div> <div class='k-content'> <h1 id="keras-3-benchmarks">Keras 3 benchmarks</h1> <p>We benchmark the three backends of Keras 3 (<a href="https://tensorflow.org/">TensorFlow</a>, <a href="https://jax.readthedocs.io/en/latest/">JAX</a>, <a href="https://pytorch.org/">PyTorch</a>) alongside Keras 2 with TensorFlow. 
Find code and setup details for reproducing our results <a href="https://github.com/haifeng-jin/keras-benchmarks/tree/v0.0.5">here</a>.</p> <h2 id="models">Models</h2> <p>We chose a set of popular computer vision and natural language processing models for both generative and non-generative AI tasks. See the table below for our selections.</p> <p><strong>Table 1</strong>: Models used in benchmarking.</p> <table> <thead> <tr> <th style="text-align: center;"></th> <th style="text-align: center;">Non-Generative</th> <th style="text-align: center;">Generative</th> </tr> </thead> <tbody> <tr> <td style="text-align: center;">CV</td> <td style="text-align: center;">SegmentAnything<sup>1</sup></td> <td style="text-align: center;">StableDiffusion<sup>2</sup></td> </tr> <tr> <td style="text-align: center;">NLP</td> <td style="text-align: center;">BERT<sup>3</sup></td> <td style="text-align: center;">Gemma<sup>4</sup>, Mistral<sup>5</sup></td> </tr> </tbody> </table> <p>We are not measuring the best possible performance achievable by each framework, but rather the out-of-the-box performance of common user workflows. With this goal in mind, we used the existing implementations from KerasCV and KerasHub for the Keras versions of the models.</p> <h2 id="hardware">Hardware</h2> <p>All benchmarks were run on a single NVIDIA A100 GPU with 40GB of GPU memory on a Google Cloud Compute Engine instance of machine type <code>a2-highgpu-1g</code> with 12 vCPUs and 85GB of host memory.</p> <h2 id="results">Results</h2> <p>Table 2 displays benchmarking results in milliseconds per step. Each step involves training or predicting on a single data batch. Results are averaged over 100 steps, excluding the first, which includes model creation and compilation overhead.</p> <p>For a fair comparison, we use the same batch size across frameworks for the same model and task (fit or predict). 
However, across different models and tasks, which differ in size and architecture, we use different batch sizes to avoid either running out of memory (batch too large) or underutilizing the GPU (batch too small).</p> <p>For the large language models (Gemma and Mistral), we also used the same batch size, since they are the same model type with a similar number of parameters (7B). We also benchmarked text generation with a batch size of 1, since it is widely requested by users. We used <code>bfloat16</code> precision for their training and inference, and LoRA<sup>6</sup> for their training (fine-tuning).</p> <p>To measure out-of-the-box performance, we try to use all default settings. For example, we use high-level APIs (e.g. Keras <code>model.fit()</code>) with as little configuration as possible.</p> <p>Note that this is quite different from measuring an optimized implementation for a particular hardware/framework/model combination. Please refer to <a href="https://mlcommons.org/benchmarks/">MLPerf</a> for the best optimized results for different frameworks.</p> <p><strong>Table 2</strong>: Benchmarking results. The speed is measured in ms/step. 
Lower is better.</p> <table> <thead> <tr> <th style="text-align: center;"></th> <th style="text-align: right;">Batch<br>size</th> <th style="text-align: right;">Keras 2<br>(TensorFlow)</th> <th style="text-align: right;">Keras 3<br>(TensorFlow)</th> <th style="text-align: right;">Keras 3<br>(JAX)</th> <th style="text-align: right;">Keras 3<br>(PyTorch)<br>(eager)</th> <th style="text-align: right;">Keras 3<br>(best)</th> </tr> </thead> <tbody> <tr> <td style="text-align: center;"><strong>SegmentAnything<br>(fit)</strong></td> <td style="text-align: right;">1</td> <td style="text-align: right;">386.93</td> <td style="text-align: right;"><strong>355.25</strong></td> <td style="text-align: right;">361.69</td> <td style="text-align: right;">1,388.87</td> <td style="text-align: right;"><strong>355.25</strong></td> </tr> <tr> <td style="text-align: center;"><strong>SegmentAnything<br>(predict)</strong></td> <td style="text-align: right;">4</td> <td style="text-align: right;">1,859.27</td> <td style="text-align: right;">438.50</td> <td style="text-align: right;"><strong>376.34</strong></td> <td style="text-align: right;">1,720.96</td> <td style="text-align: right;"><strong>376.34</strong></td> </tr> <tr> <td style="text-align: center;"><strong>Stable Diffusion<br>(fit)</strong></td> <td style="text-align: right;">8</td> <td style="text-align: right;">1,023.21</td> <td style="text-align: right;">392.24</td> <td style="text-align: right;"><strong>391.21</strong></td> <td style="text-align: right;">823.44</td> <td style="text-align: right;"><strong>391.21</strong></td> </tr> <tr> <td style="text-align: center;"><strong>Stable Diffusion<br>(predict)</strong></td> <td style="text-align: right;">13</td> <td style="text-align: right;">649.71</td> <td style="text-align: right;"><strong>616.04</strong></td> <td style="text-align: right;">627.27</td> <td style="text-align: right;">1,337.17</td> <td style="text-align: right;"><strong>616.04</strong></td> </tr> <tr> <td 
style="text-align: center;"><strong>BERT<br>(fit)</strong></td> <td style="text-align: right;">32</td> <td style="text-align: right;">486.00</td> <td style="text-align: right;"><strong>214.49</strong></td> <td style="text-align: right;">222.37</td> <td style="text-align: right;">808.68</td> <td style="text-align: right;"><strong>214.49</strong></td> </tr> <tr> <td style="text-align: center;"><strong>BERT<br>(predict)</strong></td> <td style="text-align: right;">256</td> <td style="text-align: right;">470.12</td> <td style="text-align: right;">466.01</td> <td style="text-align: right;"><strong>418.72</strong></td> <td style="text-align: right;">1,865.98</td> <td style="text-align: right;"><strong>418.72</strong></td> </tr> <tr> <td style="text-align: center;"><strong>Gemma<br>(fit)</strong></td> <td style="text-align: right;">8</td> <td style="text-align: right;">NA</td> <td style="text-align: right;">232.52</td> <td style="text-align: right;">273.67</td> <td style="text-align: right;">525.15</td> <td style="text-align: right;"><strong>232.52</strong></td> </tr> <tr> <td style="text-align: center;"><strong>Gemma<br>(generate)</strong></td> <td style="text-align: right;">32</td> <td style="text-align: right;">NA</td> <td style="text-align: right;">1,134.91</td> <td style="text-align: right;"><strong>1,128.21</strong></td> <td style="text-align: right;">7,952.67<sup>*</sup></td> <td style="text-align: right;"><strong>1,128.21</strong></td> </tr> <tr> <td style="text-align: center;"><strong>Gemma<br>(generate)</strong></td> <td style="text-align: right;">1</td> <td style="text-align: right;">NA</td> <td style="text-align: right;">758.57</td> <td style="text-align: right;"><strong>703.46</strong></td> <td style="text-align: right;">7,649.40<sup>*</sup></td> <td style="text-align: right;"><strong>703.46</strong></td> </tr> <tr> <td style="text-align: center;"><strong>Mistral<br>(fit)</strong></td> <td style="text-align: right;">8</td> <td style="text-align: 
right;">NA</td> <td style="text-align: right;"><strong>185.92</strong></td> <td style="text-align: right;">213.22</td> <td style="text-align: right;">452.12</td> <td style="text-align: right;"><strong>185.92</strong></td> </tr> <tr> <td style="text-align: center;"><strong>Mistral<br>(generate)</strong></td> <td style="text-align: right;">32</td> <td style="text-align: right;">NA</td> <td style="text-align: right;">966.06</td> <td style="text-align: right;"><strong>957.25</strong></td> <td style="text-align: right;">10,932.59<sup>*</sup></td> <td style="text-align: right;"><strong>957.25</strong></td> </tr> <tr> <td style="text-align: center;"><strong>Mistral<br>(generate)</strong></td> <td style="text-align: right;">1</td> <td style="text-align: right;">NA</td> <td style="text-align: right;">743.28</td> <td style="text-align: right;"><strong>679.30</strong></td> <td style="text-align: right;">11,054.67<sup>*</sup></td> <td style="text-align: right;"><strong>679.30</strong></td> </tr> </tbody> </table> <p>* <em>LLM inference with the PyTorch backend is abnormally slow at this time because KerasHub uses static sequence padding, unlike HuggingFace. This will be addressed soon.</em></p> <h2 id="discussion">Discussion</h2> <h3 id="key-finding-1-there-is-no-best-backend">Key Finding 1: There is no "best" backend</h3> <p>Each of the three backends of Keras offers unique strengths. Crucially, from a performance standpoint, there's no single backend that consistently outpaces the others. The fastest backend often depends on your specific model architecture.</p> <p>This underscores the value of framework optionality when chasing optimal performance. 
Keras 3 empowers you to seamlessly switch backends, ensuring you find the ideal match for your model.</p> <h3 id="key-finding-2-keras-3-is-faster-than-keras-2">Key Finding 2: Keras 3 is faster than Keras 2</h3> <p>We also calculated the throughput (steps/ms) increase of Keras 3 (using its best-performing backend) over Keras 2 with TensorFlow from Table 2. Results are shown in the following figure.</p> <p><img alt="Figure 1" src="https://i.imgur.com/jPncf0F.png" /></p> <p><strong>Figure 1</strong>: Keras 3 speedup over Keras 2 measured in throughput (steps/ms)</p> <p>Keras 3 consistently outperformed Keras 2 across all benchmarked models for which Keras 2 results are available, with substantial speed increases in many cases. SegmentAnything inference saw a remarkable 380% boost, StableDiffusion training throughput increased by over 150%, and BERT training throughput rose by over 100%.</p> <p>Importantly, you would still see a performance boost even if you simply upgraded to Keras 3 and continued using the TensorFlow backend. This is mainly because Keras 2 uses more TensorFlow fused ops directly, which may be sub-optimal for XLA compilation in certain use cases.</p> <h2 id="conclusions">Conclusions</h2> <p>Framework performance depends heavily on the specific model. Keras 3 empowers you to select the fastest framework for your task, an option that almost always outperforms Keras 2.</p> <h2 id="references">References</h2> <p><sup>1</sup> Kirillov, Alexander, et al. "Segment anything." ICCV (2023).</p> <p><sup>2</sup> Rombach, Robin, et al. "High-resolution image synthesis with latent diffusion models." CVPR (2022).</p> <p><sup>3</sup> Devlin, Jacob, et al. "BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding." NAACL (2019).</p> <p><sup>4</sup> Banks, Jeanine, et al. "Gemma: Introducing new state-of-the-art open models." The Keyword, Google (2024).</p> <p><sup>5</sup> Jiang, Albert Q., et al. "Mistral 7B." 
arXiv preprint arXiv:2310.06825 (2023).</p> <p><sup>6</sup> Hu, Edward J., et al. "Lora: Low-rank adaptation of large language models." ICLR (2022).</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#keras-3-benchmarks'>Keras 3 benchmarks</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#models'>Models</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#hardware'>Hardware</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#results'>Results</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#discussion'>Discussion</a> </div> <div class='k-outline-depth-3'> <a href='#key-finding-1-there-is-no-best-backend'>Key Finding 1: There is no "best" backend</a> </div> <div class='k-outline-depth-3'> <a href='#key-finding-2-keras-3-is-faster-than-keras-2'>Key Finding 2: Keras 3 is faster than Keras 2</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#conclusions'>Conclusions</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#references'>References</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>
