<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/keras_hub/guides/stable_diffusion_3_in_keras_hub/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Stable Diffusion 3 in KerasHub!"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Stable Diffusion 3 in KerasHub!"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Stable Diffusion 3 in KerasHub!</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); 
</script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link active" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-sublink" href="/keras_hub/getting_started/">Getting started</a> <a class="nav-sublink active" href="/keras_hub/guides/">Developer guides</a> <a class="nav-sublink2" href="/keras_hub/guides/upload/">Uploading Models</a> <a class="nav-sublink2 active" href="/keras_hub/guides/stable_diffusion_3_in_keras_hub/">Stable Diffusion 3</a> <a class="nav-sublink2" href="/keras_hub/guides/segment_anything_in_keras_hub/">Segment Anything</a> <a class="nav-sublink2" href="/keras_hub/guides/classification_with_keras_hub/">Image Classification</a> <a class="nav-sublink2" href="/keras_hub/guides/semantic_segmentation_deeplab_v3/">Semantic Segmentation</a> <a 
class="nav-sublink2" href="/keras_hub/guides/transformer_pretraining/">Pretraining a Transformer from scratch</a> <a class="nav-sublink" href="/keras_hub/api/">API documentation</a> <a class="nav-sublink" href="/keras_hub/presets/">Pretrained models list</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." 
autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/keras_hub/'>KerasHub: Pretrained Models</a> / <a href='/keras_hub/guides/'>Developer guides</a> / Stable Diffusion 3 in KerasHub! 
</div> <div class='k-content'> <h1 id="stable-diffusion-3-in-kerashub">Stable Diffusion 3 in KerasHub!</h1> <p><strong>Author:</strong> <a href="https://github.com/james77777778">Hongyu Chiu</a>, <a href="https://twitter.com/fchollet">fchollet</a>, <a href="https://twitter.com/luke_wood_ml">lukewood</a>, <a href="https://github.com/divamgupta">divamgupta</a><br> <strong>Date created:</strong> 2024/10/09<br> <strong>Last modified:</strong> 2024/10/24<br> <strong>Description:</strong> Image generation using KerasHub's Stable Diffusion 3 model.</p> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/keras_hub/stable_diffusion_3_in_keras_hub.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/guides/keras_hub/stable_diffusion_3_in_keras_hub.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="overview">Overview</h2> <p>Stable Diffusion 3 is a powerful, open-source latent diffusion model (LDM) designed to generate high-quality novel images based on text prompts. 
Released by <a href="https://stability.ai/">Stability AI</a>, it was pre-trained on 1 billion images and fine-tuned on 33 million high-quality aesthetic and preference images, resulting in greatly improved performance compared to previous versions of Stable Diffusion.</p> <p>In this guide, we will explore KerasHub's implementation of <a href="https://huggingface.co/stabilityai/stable-diffusion-3-medium">Stable Diffusion 3 Medium</a>, including the text-to-image, image-to-image, and inpainting tasks.</p> <p>To get started, let's install a few dependencies and download the images for our demo:</p> <div class="codehilite"><pre><span></span><code><span class="err">!</span><span class="n">pip</span> <span class="n">install</span> <span class="o">-</span><span class="n">Uq</span> <span class="n">keras</span> <span class="err">!</span><span class="n">pip</span> <span class="n">install</span> <span class="o">-</span><span class="n">Uq</span> <span class="n">git</span><span class="o">+</span><span class="n">https</span><span class="p">:</span><span class="o">//</span><span class="n">github</span><span class="o">.</span><span class="n">com</span><span class="o">/</span><span class="n">keras</span><span class="o">-</span><span class="n">team</span><span class="o">/</span><span class="n">keras</span><span class="o">-</span><span class="n">hub</span><span class="o">.</span><span class="n">git</span> <span class="err">!</span><span class="n">wget</span> <span class="o">-</span><span class="n">O</span> <span class="n">mountain_dog</span><span class="o">.</span><span class="n">png</span> <span class="n">https</span><span class="p">:</span><span class="o">//</span><span class="n">raw</span><span class="o">.</span><span class="n">githubusercontent</span><span class="o">.</span><span class="n">com</span><span class="o">/</span><span class="n">keras</span><span class="o">-</span><span class="n">team</span><span class="o">/</span><span class="n">keras</span><span class="o">-</span><span
class="n">io</span><span class="o">/</span><span class="n">master</span><span class="o">/</span><span class="n">guides</span><span class="o">/</span><span class="n">img</span><span class="o">/</span><span class="n">stable_diffusion_3_in_keras_hub</span><span class="o">/</span><span class="n">mountain_dog</span><span class="o">.</span><span class="n">png</span> <span class="err">!</span><span class="n">wget</span> <span class="o">-</span><span class="n">O</span> <span class="n">mountain_dog_mask</span><span class="o">.</span><span class="n">png</span> <span class="n">https</span><span class="p">:</span><span class="o">//</span><span class="n">raw</span><span class="o">.</span><span class="n">githubusercontent</span><span class="o">.</span><span class="n">com</span><span class="o">/</span><span class="n">keras</span><span class="o">-</span><span class="n">team</span><span class="o">/</span><span class="n">keras</span><span class="o">-</span><span class="n">io</span><span class="o">/</span><span class="n">master</span><span class="o">/</span><span class="n">guides</span><span class="o">/</span><span class="n">img</span><span class="o">/</span><span class="n">stable_diffusion_3_in_keras_hub</span><span class="o">/</span><span class="n">mountain_dog_mask</span><span class="o">.</span><span class="n">png</span> </code></pre></div> <div class="codehilite"><pre><span></span><code><span class="kn">import</span><span class="w"> </span><span class="nn">os</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">&quot;KERAS_BACKEND&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="s2">&quot;jax&quot;</span> <span class="kn">import</span><span class="w"> </span><span class="nn">time</span> <span class="kn">import</span><span class="w"> </span><span class="nn">keras</span> <span class="kn">import</span><span class="w"> </span><span class="nn">keras_hub</span> <span 
class="kn">import</span><span class="w"> </span><span class="nn">matplotlib.pyplot</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">plt</span> <span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span> <span class="kn">from</span><span class="w"> </span><span class="nn">PIL</span><span class="w"> </span><span class="kn">import</span> <span class="n">Image</span> </code></pre></div> <hr /> <h2 id="introduction">Introduction</h2> <p>Before diving into how latent diffusion models work, let's start by generating some images using KerasHub's APIs.</p> <p>To avoid reinitializing variables for different tasks, we'll instantiate and load the trained <code>backbone</code> and <code>preprocessor</code> using KerasHub's <code>from_preset</code> factory method. If you only want to perform one task at a time, you can use a simpler API like this:</p> <div class="codehilite"><pre><span></span><code><span class="n">text_to_image</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">StableDiffusion3TextToImage</span><span class="o">.</span><span class="n">from_preset</span><span class="p">(</span> <span class="s2">&quot;stable_diffusion_3_medium&quot;</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&quot;float16&quot;</span> <span class="p">)</span> </code></pre></div> <p>This automatically loads and configures the trained <code>backbone</code> and <code>preprocessor</code> for you.</p> <p>Note that in this guide, we'll use <code>image_shape=(512, 512, 3)</code> for faster image generation. For higher-quality output, we recommend the default size of <code>1024</code>.
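</p>
<p>To make the speed difference concrete, here is some back-of-the-envelope arithmetic. It is a sketch under assumptions this guide does not state: SD3's VAE downsamples each image side by 8 into a 16-channel latent, and the MMDiT backbone splits that latent into 2x2 patches, one token per patch. It also ignores the text tokens the backbone attends over jointly, so the real speedup is likely smaller than the attention-only estimate.</p>

```python
# Back-of-the-envelope token counts for the SD3 backbone at two resolutions.
# Assumptions (not taken from this guide): the VAE downsamples by 8x, the
# latent has 16 channels, and the backbone groups latents into 2x2 patches.
VAE_DOWNSAMPLE = 8
LATENT_CHANNELS = 16
PATCH_SIZE = 2


def latent_shape(image_size):
    """Latent (height, width, channels) for a square RGB image."""
    side = image_size // VAE_DOWNSAMPLE
    return (side, side, LATENT_CHANNELS)


def num_image_tokens(image_size):
    """Number of image tokens the transformer processes (text tokens ignored)."""
    side = image_size // VAE_DOWNSAMPLE // PATCH_SIZE
    return side * side


for size in (512, 1024):
    print(size, latent_shape(size), num_image_tokens(size))

# Self-attention cost grows with the square of the token count.
ratio = (num_image_tokens(1024) / num_image_tokens(512)) ** 2
print("attention cost ratio, 1024 vs 512:", ratio)
```

<p>Under these assumptions, halving the side length cuts the number of image tokens by 4x, which reduces the cost of the per-token layers by about 4x and of self-attention by about 16x.</p>
<p>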
Since the entire backbone has about 3 billion parameters, which can be challenging to fit into a consumer-level GPU, we set <code>dtype="float16"</code> to reduce GPU memory usage; the officially released weights are also in float16.</p> <p>It is also worth noting that the preset "stable_diffusion_3_medium" excludes the T5XXL text encoder, as it requires significantly more GPU memory. The resulting performance degradation is negligible in most cases. The weights, including T5XXL, will be available on KerasHub soon.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">display_generated_images</span><span class="p">(</span><span class="n">images</span><span class="p">):</span> <span class="w"> </span><span class="sd">&quot;&quot;&quot;Helper function to display the images from the inputs.</span> <span class="sd"> This function accepts the following input formats:</span> <span class="sd"> - 3D numpy array.</span> <span class="sd"> - 4D numpy array: concatenated horizontally.</span> <span class="sd"> - List of 3D numpy arrays: concatenated horizontally.</span> <span class="sd"> &quot;&quot;&quot;</span> <span class="n">display_image</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">):</span> <span class="k">if</span> <span class="n">images</span><span class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">3</span><span class="p">:</span> <span class="n">display_image</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">fromarray</span><span class="p">(</span><span class="n">images</span><span class="p">)</span> <span class="k">elif</span> <span class="n">images</span><span
class="o">.</span><span class="n">ndim</span> <span class="o">==</span> <span class="mi">4</span><span class="p">:</span> <span class="n">concated_images</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">images</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="n">display_image</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">fromarray</span><span class="p">(</span><span class="n">concated_images</span><span class="p">)</span> <span class="k">elif</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span> <span class="n">concated_images</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="n">display_image</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">fromarray</span><span class="p">(</span><span class="n">concated_images</span><span class="p">)</span> <span class="k">if</span> <span class="n">display_image</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;Unsupported input format.&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span 
class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">&quot;off&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">display_image</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">close</span><span class="p">()</span> <span class="n">backbone</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">StableDiffusion3Backbone</span><span class="o">.</span><span class="n">from_preset</span><span class="p">(</span> <span class="s2">&quot;stable_diffusion_3_medium&quot;</span><span class="p">,</span> <span class="n">image_shape</span><span class="o">=</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">&quot;float16&quot;</span> <span class="p">)</span> <span class="n">preprocessor</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">StableDiffusion3TextToImagePreprocessor</span><span class="o">.</span><span class="n">from_preset</span><span class="p">(</span> <span class="s2">&quot;stable_diffusion_3_medium&quot;</span> <span class="p">)</span> <span class="n">text_to_image</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span 
class="n">StableDiffusion3TextToImage</span><span class="p">(</span><span class="n">backbone</span><span class="p">,</span> <span class="n">preprocessor</span><span class="p">)</span> </code></pre></div> <p>Next, we give it a prompt:</p> <div class="codehilite"><pre><span></span><code><span class="n">prompt</span> <span class="o">=</span> <span class="s2">&quot;Astronaut in a jungle, cold color palette, muted colors, detailed, 8k&quot;</span> <span class="c1"># When using JAX or TensorFlow backends, you might experience a significant</span> <span class="c1"># compilation time during the first `generate()` call. The subsequent</span> <span class="c1"># `generate()` call speedup highlights the power of JIT compilation and caching</span> <span class="c1"># in frameworks like JAX and TensorFlow, making them well-suited for</span> <span class="c1"># high-performance deep learning tasks like image generation.</span> <span class="n">generated_image</span> <span class="o">=</span> <span class="n">text_to_image</span><span class="o">.</span><span class="n">generate</span><span class="p">(</span><span class="n">prompt</span><span class="p">)</span> <span class="n">display_generated_images</span><span class="p">(</span><span class="n">generated_image</span><span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/guides/stable_diffusion_3_in_keras_hub/stable_diffusion_3_in_keras_hub_7_0.png" /></p> <p>Pretty impressive! But how does this work?</p> <p>Let's dig into what "latent diffusion model" means.</p> <p>Consider the concept of "super-resolution," where a deep learning model "denoises" an input image, turning it into a higher-resolution version. The model uses its training data distribution to hallucinate the visual details that are most likely given the input. 
To learn more about super-resolution, you can check out the following Keras.io tutorials:</p> <ul> <li><a href="https://keras.io/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a></li> <li><a href="https://keras.io/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a></li> </ul> <p><img alt="Super-resolution" src="https://i.imgur.com/M0XdqOo.png" /></p> <p>When we push this idea to the limit, we may start asking: what if we just run such a model on pure noise? The model would then "denoise the noise" and start hallucinating a brand new image. By repeating the process multiple times, we can turn a small patch of noise into an increasingly clear and high-resolution artificial picture.</p> <p>This is the key idea of latent diffusion, proposed in <a href="https://arxiv.org/abs/2112.10752">High-Resolution Image Synthesis with Latent Diffusion Models</a>. To understand diffusion in depth, you can check out the Keras.io tutorial <a href="https://keras.io/examples/generative/ddim/">Denoising Diffusion Implicit Models</a>.</p> <p><img alt="Denoising diffusion" src="https://i.imgur.com/FSCKtZq.gif" /></p> <p>To transition from latent diffusion to a text-to-image system, one key feature must be added: the ability to control the generated visual content using prompt keywords. In Stable Diffusion 3, the text encoders from the CLIP and T5XXL models are used to obtain text embeddings, which are then fed into the diffusion model to condition the diffusion process.
How strongly this conditioning steers the output is controlled with "classifier-free guidance", proposed in <a href="https://arxiv.org/abs/2207.12598">Classifier-Free Diffusion Guidance</a>.</p> <p>When we combine these ideas, we get a high-level overview of the architecture of Stable Diffusion 3:</p> <ul> <li>Text encoders: Convert the text prompt into text embeddings.</li> <li>Diffusion model: Repeatedly "denoises" a smaller latent image patch.</li> <li>Decoder: Transforms the final latent patch into a higher-resolution image.</li> </ul> <p>First, the text prompt is converted into text embeddings by multiple text encoders, which are pretrained and frozen language models. Next, the text embeddings, along with a randomly generated noise patch (typically sampled from a Gaussian distribution), are fed into the diffusion model. The diffusion model repeatedly "denoises" the noise patch over a series of steps (more steps generally yield a clearer, more refined image; the default is 28 steps). Finally, the latent patch is passed through the decoder from the VAE model to render the image in high resolution.</p> <p>An overview of the Stable Diffusion 3 architecture: <img alt="The Stable Diffusion 3 architecture" src="https://i.imgur.com/D9y0fWF.png" /></p> <p>This relatively simple system starts looking like magic once we train on billions of pictures and their captions. As Feynman said about the universe: <em>"It's not complicated, it's just a lot of it!"</em></p> <hr /> <h2 id="texttoimage-task">Text-to-image task</h2> <p>Now we know the basics of Stable Diffusion 3 and the text-to-image task.
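</p>
<p>Before turning to the real API, the denoising loop at the heart of the pipeline can be sketched in a few lines of NumPy. This is a toy, not KerasHub's code. It assumes the rectified-flow formulation SD3 is trained with (a straight-line path <code>x_t = (1 - t) * x0 + t * noise</code>), a linear timestep grid, and a plain Euler update; the real scheduler spaces its timesteps differently. The text encoders and VAE decoder are left out, and <code>oracle_velocity</code> is a hypothetical stand-in for the diffusion model that is handed the clean latent, which a trained model instead has to predict.</p>

```python
import numpy as np

# Toy version of the text-to-image loop: start from Gaussian noise and
# "denoise" it step by step. Assumptions (not KerasHub's implementation):
#   * rectified-flow path x_t = (1 - t) * x0 + t * noise, t from 1.0 to 0.0;
#   * a linear timestep grid and a plain Euler update;
#   * `oracle_velocity` is a hypothetical stand-in for the diffusion model
#     that already knows the clean latent x0.


def oracle_velocity(x_t, t, x0):
    # Along the straight path, dx/dt = noise - x0, which equals (x_t - x0) / t.
    return (x_t - x0) / t


def euler_sample(x0, num_steps=28, seed=0):
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(x0.shape)  # pure noise at t = 1.0
    timesteps = np.linspace(1.0, 0.0, num_steps + 1)
    for t, t_next in zip(timesteps[:-1], timesteps[1:]):
        velocity = oracle_velocity(x, t, x0)
        x = x + (t_next - t) * velocity  # one Euler step towards t = 0.0
    return x


clean_latent = np.full((4, 4, 16), 0.5)  # hypothetical clean latent patch
result = euler_sample(clean_latent, num_steps=28)
print(np.abs(result - clean_latent).max())  # prints a value close to 0
```

<p>Because the oracle is exact, even <code>num_steps=1</code> lands on the target here. A trained model's predictions are only approximate, which is one reason more steps usually produce cleaner images at a higher cost.</p>
<p>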
Let's explore further using KerasHub APIs.</p> <p>To use KerasHub's APIs for efficient batch processing, we can provide the model with a list of prompts:</p> <div class="codehilite"><pre><span></span><code><span class="n">generated_images</span> <span class="o">=</span> <span class="n">text_to_image</span><span class="o">.</span><span class="n">generate</span><span class="p">([</span><span class="n">prompt</span><span class="p">]</span> <span class="o">*</span> <span class="mi">3</span><span class="p">)</span> <span class="n">display_generated_images</span><span class="p">(</span><span class="n">generated_images</span><span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/guides/stable_diffusion_3_in_keras_hub/stable_diffusion_3_in_keras_hub_10_0.png" /></p> <p>The <code>num_steps</code> parameter controls the number of denoising steps used during image generation. Increasing the number of steps typically leads to higher quality images at the expense of increased generation time. 
In Stable Diffusion 3, this parameter defaults to <code>28</code>.</p> <div class="codehilite"><pre><span></span><code><span class="n">num_steps</span> <span class="o">=</span> <span class="p">[</span><span class="mi">10</span><span class="p">,</span> <span class="mi">28</span><span class="p">,</span> <span class="mi">50</span><span class="p">]</span> <span class="n">generated_images</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="n">num_steps</span><span class="p">:</span> <span class="n">st</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span> <span class="n">generated_images</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">text_to_image</span><span class="o">.</span><span class="n">generate</span><span class="p">(</span><span class="n">prompt</span><span class="p">,</span> <span class="n">num_steps</span><span class="o">=</span><span class="n">n</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Cost time (`num_steps=</span><span class="si">{</span><span class="n">n</span><span class="si">}</span><span class="s2">`): </span><span class="si">{</span><span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span><span class="w"> </span><span class="o">-</span><span class="w"> </span><span class="n">st</span><span class="si">:</span><span class="s2">.2f</span><span class="si">}</span><span class="s2">s&quot;</span><span class="p">)</span> <span class="n">display_generated_images</span><span class="p">(</span><span class="n">generated_images</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Cost time (`num_steps=10`): 1.35s Cost time 
(`num_steps=28`): 3.44s Cost time (`num_steps=50`): 6.18s </code></pre></div> </div> <p><img alt="png" src="/img/guides/stable_diffusion_3_in_keras_hub/stable_diffusion_3_in_keras_hub_12_3.png" /></p> <p>We can use <code>"negative_prompts"</code> to guide the model away from generating specific styles and elements. The input format becomes a dict with the keys <code>"prompts"</code> and <code>"negative_prompts"</code>.</p> <p>If <code>"negative_prompts"</code> is not provided, the empty prompt <code>""</code> is used as the unconditioned default.</p> <div class="codehilite"><pre><span></span><code><span class="n">generated_images</span> <span class="o">=</span> <span class="n">text_to_image</span><span class="o">.</span><span class="n">generate</span><span class="p">(</span> <span class="p">{</span> <span class="s2">&quot;prompts&quot;</span><span class="p">:</span> <span class="p">[</span><span class="n">prompt</span><span class="p">]</span> <span class="o">*</span> <span class="mi">3</span><span class="p">,</span> <span class="s2">&quot;negative_prompts&quot;</span><span class="p">:</span> <span class="p">[</span><span class="s2">&quot;Green color&quot;</span><span class="p">]</span> <span class="o">*</span> <span class="mi">3</span><span class="p">,</span> <span class="p">}</span> <span class="p">)</span> <span class="n">display_generated_images</span><span class="p">(</span><span class="n">generated_images</span><span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/guides/stable_diffusion_3_in_keras_hub/stable_diffusion_3_in_keras_hub_14_0.png" /></p> <p><code>guidance_scale</code> controls how strongly the <code>"prompts"</code> influence image generation. A lower value gives the model more creative freedom, producing images that are more loosely related to the prompt. Higher values push the model to follow the prompt more closely. If this value is too high, you may observe some artifacts in the generated image.
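</p>
<p>The interplay between <code>negative_prompts</code> and <code>guidance_scale</code> comes down to a linear extrapolation between two model predictions: one conditioned on the prompt and one on the negative (or empty) prompt. The sketch below writes that combination out in NumPy; the arrays are made-up numbers standing in for model outputs, and <code>apply_guidance</code> is a hypothetical helper name.</p>

```python
import numpy as np

# Classifier-free guidance, using the combination formula from this guide:
#   predicted = negative + guidance_scale * (positive - negative)
# `positive` stands for the model output conditioned on "prompts";
# `negative` for the output conditioned on "negative_prompts" (or "").


def apply_guidance(positive, negative, guidance_scale):
    return negative + guidance_scale * (positive - negative)


positive = np.array([1.0, 2.0])  # hypothetical model outputs
negative = np.array([0.5, 2.0])

for scale in (1.0, 2.5, 7.0):
    print(scale, apply_guidance(positive, negative, scale))
# scale 1.0 gives [1.0, 2.0]: exactly the positive prediction.
# scale 2.5 gives [1.75, 2.0] and scale 7.0 gives [4.0, 2.0]: the first
# component is pushed past the positive prediction, while the second, where
# both predictions agree, is unchanged.
```

<p>A scale of 1.0 reproduces the conditioned prediction, and larger scales overshoot it along the direction pointing away from the negative prompt. This overshoot is one plausible intuition for why very large values can introduce artifacts.</p>
<p>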
In Stable Diffusion 3, <code>guidance_scale</code> defaults to <code>7.0</code>.</p> <div class="codehilite"><pre><span></span><code><span class="n">generated_images</span> <span class="o">=</span> <span class="p">[</span> <span class="n">text_to_image</span><span class="o">.</span><span class="n">generate</span><span class="p">(</span><span class="n">prompt</span><span class="p">,</span> <span class="n">guidance_scale</span><span class="o">=</span><span class="mf">2.5</span><span class="p">),</span> <span class="n">text_to_image</span><span class="o">.</span><span class="n">generate</span><span class="p">(</span><span class="n">prompt</span><span class="p">,</span> <span class="n">guidance_scale</span><span class="o">=</span><span class="mf">7.0</span><span class="p">),</span> <span class="n">text_to_image</span><span class="o">.</span><span class="n">generate</span><span class="p">(</span><span class="n">prompt</span><span class="p">,</span> <span class="n">guidance_scale</span><span class="o">=</span><span class="mf">10.5</span><span class="p">),</span> <span class="p">]</span> <span class="n">display_generated_images</span><span class="p">(</span><span class="n">generated_images</span><span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/guides/stable_diffusion_3_in_keras_hub/stable_diffusion_3_in_keras_hub_16_0.png" /></p> <p>Note that <code>negative_prompts</code> and <code>guidance_scale</code> are related. The formula in the implementation can be written as follows: <code>predicted_noise = negative_noise + guidance_scale * (positive_noise - negative_noise)</code>.</p> <hr /> <h2 id="imagetoimage-task">Image-to-image task</h2> <p>A reference image can be used as a starting point for the diffusion process. This requires an additional module in the pipeline: the encoder from the VAE model.</p> <p>The reference image is encoded by the VAE encoder into the latent space, where noise is then added.
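</p>
<p>The "add noise to the encoded reference" step can be sketched as follows. This is an illustration under assumptions, not KerasHub's code: it uses the rectified-flow interpolation SD3 is trained with, <code>noisy = (1 - t) * latent + t * noise</code>, and a random array stands in for the VAE encoder output of a 512x512 image (64x64 spatial, 16 channels). <code>add_noise</code> and <code>reference_latent</code> are hypothetical names.</p>

```python
import numpy as np

# Sketch of the "add noise to the encoded reference" step of image-to-image.
# Assumption: SD3-style rectified-flow noising,
#   noisy = (1 - t) * latent + t * noise, with noise level t in [0, 1].
# `reference_latent` is a random placeholder for the VAE encoder output.


def add_noise(latent, t, seed=0):
    """Blend a clean latent with Gaussian noise at noise level t."""
    rng = np.random.default_rng(seed)
    noise = rng.standard_normal(latent.shape)
    return (1.0 - t) * latent + t * noise


reference_latent = np.random.default_rng(42).standard_normal((64, 64, 16))
for t in (0.0, 0.5, 1.0):
    noisy = add_noise(reference_latent, t)
    # How much of the reference survives: the correlation drops as t grows.
    corr = np.corrcoef(noisy.ravel(), reference_latent.ravel())[0, 1]
    print(f"t={t}: correlation with reference = {corr:.3f}")
```

<p>Denoising then starts from this partially noised latent instead of pure noise, so the lower the noise level, the more of the reference's structure tends to survive.</p>
<p>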
The subsequent denoising steps follow the same procedure as the text-to-image task.</p> <p>The input format becomes a dict with the keys <code>"images"</code>, <code>"prompts"</code> and optionally <code>"negative_prompts"</code>.</p> <div class="codehilite"><pre><span></span><code><span class="n">image_to_image</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">StableDiffusion3ImageToImage</span><span class="p">(</span><span class="n">backbone</span><span class="p">,</span> <span class="n">preprocessor</span><span class="p">)</span> <span class="n">image</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="s2">&quot;mountain_dog.png&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">convert</span><span class="p">(</span><span class="s2">&quot;RGB&quot;</span><span class="p">)</span> <span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">((</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">))</span> <span class="n">width</span><span class="p">,</span> <span class="n">height</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">size</span> <span class="c1"># Note that the values of the image must be in the range of [-1.0, 1.0].</span> <span class="n">rescale</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Rescaling</span><span class="p">(</span><span class="n">scale</span><span class="o">=</span><span class="mi">1</span> <span class="o">/</span> <span class="mf">127.5</span><span class="p">,</span> <span class="n">offset</span><span 
class="o">=-</span><span class="mf">1.0</span><span class="p">)</span> <span class="n">image_array</span> <span class="o">=</span> <span class="n">rescale</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">image</span><span class="p">))</span> <span class="n">prompt</span> <span class="o">=</span> <span class="s2">&quot;dog wizard, gandalf, lord of the rings, detailed, fantasy, cute, &quot;</span> <span class="n">prompt</span> <span class="o">+=</span> <span class="s2">&quot;adorable, Pixar, Disney, 8k&quot;</span> <span class="n">generated_image</span> <span class="o">=</span> <span class="n">image_to_image</span><span class="o">.</span><span class="n">generate</span><span class="p">(</span> <span class="p">{</span> <span class="s2">&quot;images&quot;</span><span class="p">:</span> <span class="n">image_array</span><span class="p">,</span> <span class="s2">&quot;prompts&quot;</span><span class="p">:</span> <span class="n">prompt</span><span class="p">,</span> <span class="p">}</span> <span class="p">)</span> <span class="n">display_generated_images</span><span class="p">(</span> <span class="p">[</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">image</span><span class="p">),</span> <span class="n">generated_image</span><span class="p">,</span> <span class="p">]</span> <span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/guides/stable_diffusion_3_in_keras_hub/stable_diffusion_3_in_keras_hub_19_0.png" /></p> <p>As you can see, a new image is generated based on the reference image and the prompt.</p> <p>The <code>strength</code> parameter plays a key role in determining how closely the generated image resembles the reference image. 
The value must lie in the range <code>[0.0, 1.0]</code> and defaults to <code>0.8</code> in Stable Diffusion 3.</p> <p>A higher <code>strength</code> value gives the model more freedom to generate an image that departs from the reference image. At a value of <code>1.0</code>, the reference image is completely ignored, making the task purely text-to-image.</p> <p>A lower <code>strength</code> value keeps the generated image closer to the reference image.</p> <div class="codehilite"><pre><span></span><code><span class="n">generated_images</span> <span class="o">=</span> <span class="p">[</span> <span class="n">image_to_image</span><span class="o">.</span><span class="n">generate</span><span class="p">(</span> <span class="p">{</span> <span class="s2">&quot;images&quot;</span><span class="p">:</span> <span class="n">image_array</span><span class="p">,</span> <span class="s2">&quot;prompts&quot;</span><span class="p">:</span> <span class="n">prompt</span><span class="p">,</span> <span class="p">},</span> <span class="n">strength</span><span class="o">=</span><span class="mf">0.7</span><span class="p">,</span> <span class="p">),</span> <span class="n">image_to_image</span><span class="o">.</span><span class="n">generate</span><span class="p">(</span> <span class="p">{</span> <span class="s2">&quot;images&quot;</span><span class="p">:</span> <span class="n">image_array</span><span class="p">,</span> <span class="s2">&quot;prompts&quot;</span><span class="p">:</span> <span class="n">prompt</span><span class="p">,</span> <span class="p">},</span> <span class="n">strength</span><span class="o">=</span><span class="mf">0.8</span><span class="p">,</span> <span class="p">),</span> <span class="n">image_to_image</span><span class="o">.</span><span class="n">generate</span><span class="p">(</span> <span class="p">{</span> <span class="s2">&quot;images&quot;</span><span class="p">:</span> <span class="n">image_array</span><span class="p">,</span> <span
class="s2">&quot;prompts&quot;</span><span class="p">:</span> <span class="n">prompt</span><span class="p">,</span> <span class="p">},</span> <span class="n">strength</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span> <span class="p">),</span> <span class="p">]</span> <span class="n">display_generated_images</span><span class="p">(</span><span class="n">generated_images</span><span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/guides/stable_diffusion_3_in_keras_hub/stable_diffusion_3_in_keras_hub_22_0.png" /></p> <hr /> <h2 id="inpaint-task">Inpaint task</h2> <p>Building upon the image-to-image task, we can also control the generated area using a mask. This process is called inpainting, where specific areas of an image are replaced or edited.</p> <p>Inpainting relies on a mask to determine which regions of the image to modify. The areas to inpaint are represented by white pixels (<code>True</code>), while the areas to preserve are represented by black pixels (<code>False</code>).</p> <p>For inpainting, the input is a dict with the keys <code>"images"</code>, <code>"masks"</code>, <code>"prompts"</code> and optionally <code>"negative_prompts"</code>.</p> <div class="codehilite"><pre><span></span><code><span class="n">inpaint</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">StableDiffusion3Inpaint</span><span class="p">(</span><span class="n">backbone</span><span class="p">,</span> <span class="n">preprocessor</span><span class="p">)</span> <span class="n">image</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="s2">&quot;mountain_dog.png&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">convert</span><span class="p">(</span><span class="s2">&quot;RGB&quot;</span><span class="p">)</span> 
<span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">((</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">))</span> <span class="n">image_array</span> <span class="o">=</span> <span class="n">rescale</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">image</span><span class="p">))</span> <span class="c1"># Note that the mask values are of boolean dtype.</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="s2">&quot;mountain_dog_mask.png&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">convert</span><span class="p">(</span><span class="s2">&quot;L&quot;</span><span class="p">)</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">resize</span><span class="p">((</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">))</span> <span class="n">mask_array</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">mask</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">&quot;bool&quot;</span><span class="p">)</span> <span class="n">prompt</span> <span class="o">=</span> <span class="s2">&quot;a black cat with glowing eyes, cute, adorable, disney, pixar, highly &quot;</span> <span class="n">prompt</span> <span class="o">+=</span> <span class="s2">&quot;detailed, 8k&quot;</span> <span class="n">generated_image</span> <span class="o">=</span> <span class="n">inpaint</span><span class="o">.</span><span 
class="n">generate</span><span class="p">(</span> <span class="p">{</span> <span class="s2">&quot;images&quot;</span><span class="p">:</span> <span class="n">image_array</span><span class="p">,</span> <span class="s2">&quot;masks&quot;</span><span class="p">:</span> <span class="n">mask_array</span><span class="p">,</span> <span class="s2">&quot;prompts&quot;</span><span class="p">:</span> <span class="n">prompt</span><span class="p">,</span> <span class="p">}</span> <span class="p">)</span> <span class="n">display_generated_images</span><span class="p">(</span> <span class="p">[</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">image</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">mask</span><span class="o">.</span><span class="n">convert</span><span class="p">(</span><span class="s2">&quot;RGB&quot;</span><span class="p">)),</span> <span class="n">generated_image</span><span class="p">,</span> <span class="p">]</span> <span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/guides/stable_diffusion_3_in_keras_hub/stable_diffusion_3_in_keras_hub_24_0.png" /></p> <p>Fantastic! 
The dog is replaced by a cute black cat, but unlike with image-to-image, the background outside the mask is preserved.</p> <p>Note that the inpainting task also accepts a <code>strength</code> parameter to control image generation; it defaults to <code>0.6</code> in Stable Diffusion 3.</p> <hr /> <h2 id="conclusion">Conclusion</h2> <p>KerasHub's <code>StableDiffusion3</code> supports a variety of applications and, with the help of Keras 3, enables running the model on TensorFlow, JAX, and PyTorch!</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#stable-diffusion-3-in-kerashub'>Stable Diffusion 3 in KerasHub!</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#overview'>Overview</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#texttoimage-task'>Text-to-image task</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#imagetoimage-task'>Image-to-image task</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#inpaint-task'>Inpaint task</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#conclusion'>Conclusion</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>
