<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/generative/dcgan_overriding_train_step/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: DCGAN to generate face images"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: DCGAN to generate face images"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>DCGAN to generate face images</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> 
<!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink active" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink2" href="/examples/generative/ddim/">Denoising Diffusion Implicit Models</a> <a class="nav-sublink2" href="/examples/generative/random_walks_with_stable_diffusion_3/">A walk through latent space with Stable Diffusion 3</a> <a class="nav-sublink2" href="/examples/generative/dreambooth/">DreamBooth</a> <a class="nav-sublink2" href="/examples/generative/ddpm/">Denoising Diffusion Probabilistic Models</a> <a class="nav-sublink2" href="/examples/generative/fine_tune_via_textual_inversion/">Teach StableDiffusion new concepts via Textual Inversion</a> <a class="nav-sublink2" href="/examples/generative/finetune_stable_diffusion/">Fine-tuning Stable Diffusion</a> <a 
class="nav-sublink2" href="/examples/generative/vae/">Variational AutoEncoder</a> <a class="nav-sublink2 active" href="/examples/generative/dcgan_overriding_train_step/">GAN overriding Model.train_step</a> <a class="nav-sublink2" href="/examples/generative/wgan_gp/">WGAN-GP overriding Model.train_step</a> <a class="nav-sublink2" href="/examples/generative/conditional_gan/">Conditional GAN</a> <a class="nav-sublink2" href="/examples/generative/cyclegan/">CycleGAN</a> <a class="nav-sublink2" href="/examples/generative/gan_ada/">Data-efficient GANs with Adaptive Discriminator Augmentation</a> <a class="nav-sublink2" href="/examples/generative/deep_dream/">Deep Dream</a> <a class="nav-sublink2" href="/examples/generative/gaugan/">GauGAN for conditional image generation</a> <a class="nav-sublink2" href="/examples/generative/pixelcnn/">PixelCNN</a> <a class="nav-sublink2" href="/examples/generative/stylegan/">Face image generation with StyleGAN</a> <a class="nav-sublink2" href="/examples/generative/vq_vae/">Vector-Quantized Variational Autoencoders</a> <a class="nav-sublink2" href="/examples/generative/random_walks_with_stable_diffusion/">A walk through latent space with Stable Diffusion</a> <a class="nav-sublink2" href="/examples/generative/neural_style_transfer/">Neural style transfer</a> <a class="nav-sublink2" href="/examples/generative/adain/">Neural Style Transfer with AdaIN</a> <a class="nav-sublink2" href="/examples/generative/gpt2_text_generation_with_keras_hub/">GPT2 Text Generation with KerasHub</a> <a class="nav-sublink2" href="/examples/generative/text_generation_gpt/">GPT text generation from scratch with KerasHub</a> <a class="nav-sublink2" href="/examples/generative/text_generation_with_miniature_gpt/">Text generation with a miniature GPT</a> <a class="nav-sublink2" href="/examples/generative/lstm_character_level_text_generation/">Character-level text generation with LSTM</a> <a class="nav-sublink2" href="/examples/generative/text_generation_fnet/">Text 
Generation using FNet</a> <a class="nav-sublink2" href="/examples/generative/midi_generation_with_transformer/">Music Generation with Transformer Models</a> <a class="nav-sublink2" href="/examples/generative/molecule_generation/">Drug Molecule Generation with VAE</a> <a class="nav-sublink2" href="/examples/generative/wgan-graphs/">WGAN-GP with R-GCN for the generation of small molecular graphs</a> <a class="nav-sublink2" href="/examples/generative/real_nvp/">Density estimation using Real NVP</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } 
window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a 
href='/examples/generative/'>Generative Deep Learning</a> / DCGAN to generate face images </div> <div class='k-content'> <h1 id="dcgan-to-generate-face-images">DCGAN to generate face images</h1> <p><strong>Author:</strong> <a href="https://twitter.com/fchollet">fchollet</a><br> <strong>Date created:</strong> 2019/04/29<br> <strong>Last modified:</strong> 2023/12/21<br> <strong>Description:</strong> A simple DCGAN trained using <code>fit()</code> by overriding <code>train_step</code> on CelebA images.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/generative/ipynb/dcgan_overriding_train_step.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/generative/dcgan_overriding_train_step.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span><span class="w"> </span><span class="nn">keras</span> <span class="kn">import</span><span class="w"> </span><span class="nn">tensorflow</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">tf</span> <span class="kn">from</span><span class="w"> </span><span class="nn">keras</span><span class="w"> </span><span class="kn">import</span> <span class="n">layers</span> <span class="kn">from</span><span class="w"> </span><span class="nn">keras</span><span class="w"> </span><span class="kn">import</span> <span class="n">ops</span> <span class="kn">import</span><span class="w"> </span><span class="nn">matplotlib.pyplot</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">plt</span> <span 
class="kn">import</span><span class="w"> </span><span class="nn">os</span> <span class="kn">import</span><span class="w"> </span><span class="nn">gdown</span> <span class="kn">from</span><span class="w"> </span><span class="nn">zipfile</span><span class="w"> </span><span class="kn">import</span> <span class="n">ZipFile</span> </code></pre></div> <hr /> <h2 id="prepare-celeba-data">Prepare CelebA data</h2> <p>We'll use face images from the CelebA dataset, resized to 64x64.</p> <div class="codehilite"><pre><span></span><code><span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="s2">"celeba_gan"</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">url</span> <span class="o">=</span> <span class="s2">"https://drive.google.com/uc?id=1O7m1010EJjLE5QxLZiM9Fpjs7Oj6e684"</span> <span class="n">output</span> <span class="o">=</span> <span class="s2">"celeba_gan/data.zip"</span> <span class="n">gdown</span><span class="o">.</span><span class="n">download</span><span class="p">(</span><span class="n">url</span><span class="p">,</span> <span class="n">output</span><span class="p">,</span> <span class="n">quiet</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="k">with</span> <span class="n">ZipFile</span><span class="p">(</span><span class="s2">"celeba_gan/data.zip"</span><span class="p">,</span> <span class="s2">"r"</span><span class="p">)</span> <span class="k">as</span> <span class="n">zipobj</span><span class="p">:</span> <span class="n">zipobj</span><span class="o">.</span><span class="n">extractall</span><span class="p">(</span><span class="s2">"celeba_gan"</span><span class="p">)</span> </code></pre></div> <p>Create a dataset from our folder and rescale the images to the [0, 1] range:</p> <div class="codehilite"><pre><span></span><code><span class="n">dataset</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span 
class="n">utils</span><span class="o">.</span><span class="n">image_dataset_from_directory</span><span class="p">(</span> <span class="s2">"celeba_gan"</span><span class="p">,</span> <span class="n">label_mode</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">image_size</span><span class="o">=</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">64</span><span class="p">),</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span> <span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span> <span class="o">/</span> <span class="mf">255.0</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Found 202599 files. 
</code></pre></div> </div> <p>Let's display a sample image:</p> <div class="codehilite"><pre><span></span><code><span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">dataset</span><span class="p">:</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">((</span><span class="n">x</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="o">*</span> <span class="mi">255</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"int32"</span><span class="p">)[</span><span class="mi">0</span><span class="p">])</span> <span class="k">break</span> </code></pre></div> <p><img alt="png" src="/img/examples/generative/dcgan_overriding_train_step/dcgan_overriding_train_step_8_0.png" /></p> <hr /> <h2 id="create-the-discriminator">Create the discriminator</h2> <p>It maps a 64x64 image to a binary classification score.</p> <div class="codehilite"><pre><span></span><code><span class="n">discriminator</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">3</span><span class="p">)),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">4</span><span 
class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="n">negative_slope</span><span class="o">=</span><span class="mf">0.2</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="n">negative_slope</span><span class="o">=</span><span class="mf">0.2</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="n">negative_slope</span><span class="o">=</span><span class="mf">0.2</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Flatten</span><span class="p">(),</span> <span 
class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.2</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"sigmoid"</span><span class="p">),</span> <span class="p">],</span> <span class="n">name</span><span class="o">=</span><span class="s2">"discriminator"</span><span class="p">,</span> <span class="p">)</span> <span class="n">discriminator</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> </code></pre></div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "discriminator"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ conv2d (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">3,136</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ leaky_re_lu (<span style="color: #0087ff; text-decoration-color: 
#0087ff">LeakyReLU</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">131,200</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ leaky_re_lu_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">LeakyReLU</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_2 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: 
#00af00">262,272</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ leaky_re_lu_2 (<span style="color: #0087ff; text-decoration-color: #0087ff">LeakyReLU</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ flatten (<span style="color: #0087ff; text-decoration-color: #0087ff">Flatten</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">8192</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dropout (<span style="color: #0087ff; text-decoration-color: #0087ff">Dropout</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">8192</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dense (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">8,193</span> │ └─────────────────────────────────┴───────────────────────────┴────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: 
</span><span style="color: #00af00; text-decoration-color: #00af00">404,801</span> (1.54 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">404,801</span> (1.54 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <hr /> <h2 id="create-the-generator">Create the generator</h2> <p>It maps a latent vector of size <code>latent_dim</code> to a 64x64 RGB image. It mirrors the discriminator, replacing <code>Conv2D</code> layers with <code>Conv2DTranspose</code> layers that upsample an 8x8 feature map to 64x64. A final <code>Conv2D</code> with a sigmoid activation keeps pixel values in the same [0, 1] range as the rescaled training images.</p> <div class="codehilite"><pre><span></span><code><span class="n">latent_dim</span> <span class="o">=</span> <span class="mi">128</span> <span class="n">generator</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">latent_dim</span><span class="p">,)),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">8</span> <span class="o">*</span> <span class="mi">8</span> <span class="o">*</span> <span class="mi">128</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Reshape</span><span class="p">((</span><span class="mi">8</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">128</span><span class="p">)),</span> <span class="n">layers</span><span class="o">.</span><span 
class="n">Conv2DTranspose</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="n">negative_slope</span><span class="o">=</span><span class="mf">0.2</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2DTranspose</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="n">negative_slope</span><span class="o">=</span><span class="mf">0.2</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2DTranspose</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span 
class="n">negative_slope</span><span class="o">=</span><span class="mf">0.2</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"sigmoid"</span><span class="p">),</span> <span class="p">],</span> <span class="n">name</span><span class="o">=</span><span class="s2">"generator"</span><span class="p">,</span> <span class="p">)</span> <span class="n">generator</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> </code></pre></div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "generator"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ dense_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">8192</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">1,056,768</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ reshape (<span style="color: #0087ff; text-decoration-color: 
#0087ff">Reshape</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_transpose │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">262,272</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2DTranspose</span>) │ │ │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ leaky_re_lu_3 (<span style="color: #0087ff; text-decoration-color: #0087ff">LeakyReLU</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_transpose_1 │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>, <span style="color: #00af00; text-decoration-color: #00af00">256</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">524,544</span> │ │ (<span style="color: #0087ff; 
text-decoration-color: #0087ff">Conv2DTranspose</span>) │ │ │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ leaky_re_lu_4 (<span style="color: #0087ff; text-decoration-color: #0087ff">LeakyReLU</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>, <span style="color: #00af00; text-decoration-color: #00af00">256</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_transpose_2 │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">2,097,664</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2DTranspose</span>) │ │ │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ leaky_re_lu_5 (<span style="color: #0087ff; text-decoration-color: #0087ff">LeakyReLU</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_3 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; 
text-decoration-color: #00af00">64</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">38,403</span> │ └─────────────────────────────────┴───────────────────────────┴────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">3,979,651</span> (15.18 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">3,979,651</span> (15.18 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <hr /> <h2 id="override-trainstep">Override <code>train_step</code></h2> <div class="codehilite"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">GAN</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">discriminator</span><span class="p">,</span> <span class="n">generator</span><span class="p">,</span> <span class="n">latent_dim</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span 
class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span> <span class="o">=</span> <span class="n">discriminator</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span> <span class="o">=</span> <span class="n">generator</span> <span class="bp">self</span><span class="o">.</span><span class="n">latent_dim</span> <span class="o">=</span> <span class="n">latent_dim</span> <span class="bp">self</span><span class="o">.</span><span class="n">seed_generator</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">SeedGenerator</span><span class="p">(</span><span class="mi">1337</span><span class="p">)</span> <span class="k">def</span><span class="w"> </span><span class="nf">compile</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">d_optimizer</span><span class="p">,</span> <span class="n">g_optimizer</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">compile</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_optimizer</span> <span class="o">=</span> <span class="n">d_optimizer</span> <span class="bp">self</span><span class="o">.</span><span class="n">g_optimizer</span> <span class="o">=</span> <span class="n">g_optimizer</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span> <span class="o">=</span> <span class="n">loss_fn</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_loss_metric</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span 
class="o">=</span><span class="s2">"d_loss"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">g_loss_metric</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"g_loss"</span><span class="p">)</span> <span class="nd">@property</span> <span class="k">def</span><span class="w"> </span><span class="nf">metrics</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="k">return</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">d_loss_metric</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">g_loss_metric</span><span class="p">]</span> <span class="k">def</span><span class="w"> </span><span class="nf">train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">real_images</span><span class="p">):</span> <span class="c1"># Sample random points in the latent space</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">real_images</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="n">random_latent_vectors</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">latent_dim</span><span class="p">),</span> <span class="n">seed</span><span 
class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">seed_generator</span> <span class="p">)</span> <span class="c1"># Decode them to fake images</span> <span class="n">generated_images</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">(</span><span class="n">random_latent_vectors</span><span class="p">)</span> <span class="c1"># Combine them with real images</span> <span class="n">combined_images</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">generated_images</span><span class="p">,</span> <span class="n">real_images</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="c1"># Assemble labels discriminating real from fake images</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span> <span class="p">[</span><span class="n">ops</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">)),</span> <span class="n">ops</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">))],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span> <span class="p">)</span> <span class="c1"># Add random noise to the labels - important trick!</span> <span class="n">labels</span> <span class="o">+=</span> <span class="mf">0.05</span> <span class="o">*</span> <span class="n">tf</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span 
class="n">uniform</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">labels</span><span class="p">))</span> <span class="c1"># Train the discriminator</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="n">predictions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">(</span><span class="n">combined_images</span><span class="p">)</span> <span class="n">d_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">predictions</span><span class="p">)</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">d_loss</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_optimizer</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span> <span class="nb">zip</span><span class="p">(</span><span class="n">grads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">)</span> <span class="p">)</span> <span class="c1"># Sample random points in the latent space</span> <span class="n">random_latent_vectors</span> <span class="o">=</span> <span 
class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">latent_dim</span><span class="p">),</span> <span class="n">seed</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">seed_generator</span> <span class="p">)</span> <span class="c1"># Assemble labels that say "all real images"</span> <span class="n">misleading_labels</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="c1"># Train the generator (note that we should *not* update the weights</span> <span class="c1"># of the discriminator)!</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="n">predictions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">(</span><span class="n">random_latent_vectors</span><span class="p">))</span> <span class="n">g_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span><span class="p">(</span><span class="n">misleading_labels</span><span class="p">,</span> <span class="n">predictions</span><span class="p">)</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">tape</span><span 
class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">g_loss</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">g_optimizer</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">grads</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">))</span> <span class="c1"># Update metrics</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_loss_metric</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">d_loss</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">g_loss_metric</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">g_loss</span><span class="p">)</span> <span class="k">return</span> <span class="p">{</span> <span class="s2">"d_loss"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">d_loss_metric</span><span class="o">.</span><span class="n">result</span><span class="p">(),</span> <span class="s2">"g_loss"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">g_loss_metric</span><span class="o">.</span><span class="n">result</span><span class="p">(),</span> <span class="p">}</span> </code></pre></div> <hr /> <h2 id="create-a-callback-that-periodically-saves-generated-images">Create a callback that periodically saves generated images</h2> <div class="codehilite"><pre><span></span><code><span 
class="k">class</span><span class="w"> </span><span class="nc">GANMonitor</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">Callback</span><span class="p">):</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_img</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">latent_dim</span><span class="o">=</span><span class="mi">128</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_img</span> <span class="o">=</span> <span class="n">num_img</span> <span class="bp">self</span><span class="o">.</span><span class="n">latent_dim</span> <span class="o">=</span> <span class="n">latent_dim</span> <span class="bp">self</span><span class="o">.</span><span class="n">seed_generator</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">SeedGenerator</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span> <span class="k">def</span><span class="w"> </span><span class="nf">on_epoch_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">random_latent_vectors</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_img</span><span class="p">,</span> <span
class="bp">self</span><span class="o">.</span><span class="n">latent_dim</span><span class="p">),</span> <span class="n">seed</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">seed_generator</span> <span class="p">)</span> <span class="n">generated_images</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">generator</span><span class="p">(</span><span class="n">random_latent_vectors</span><span class="p">)</span> <span class="n">generated_images</span> <span class="o">*=</span> <span class="mi">255</span> <span class="n">generated_images</span> <span class="o">=</span> <span class="n">generated_images</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_img</span><span class="p">):</span> <span class="n">img</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">array_to_img</span><span class="p">(</span><span class="n">generated_images</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="n">img</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s2">"generated_img_</span><span class="si">%03d</span><span class="s2">_</span><span class="si">%d</span><span class="s2">.png"</span> <span class="o">%</span> <span class="p">(</span><span class="n">epoch</span><span class="p">,</span> <span class="n">i</span><span class="p">))</span> </code></pre></div> <hr /> <h2 id="train-the-endtoend-model">Train the end-to-end model</h2> <div class="codehilite"><pre><span></span><code><span class="n">epochs</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># In practice, 
use ~100 epochs</span> <span class="n">gan</span> <span class="o">=</span> <span class="n">GAN</span><span class="p">(</span><span class="n">discriminator</span><span class="o">=</span><span class="n">discriminator</span><span class="p">,</span> <span class="n">generator</span><span class="o">=</span><span class="n">generator</span><span class="p">,</span> <span class="n">latent_dim</span><span class="o">=</span><span class="n">latent_dim</span><span class="p">)</span> <span class="n">gan</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">d_optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.0001</span><span class="p">),</span> <span class="n">g_optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.0001</span><span class="p">),</span> <span class="n">loss_fn</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">BinaryCrossentropy</span><span class="p">(),</span> <span class="p">)</span> <span class="n">gan</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">dataset</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">epochs</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">GANMonitor</span><span class="p">(</span><span class="n">num_img</span><span class="o">=</span><span class="mi">10</span><span 
class="p">,</span> <span class="n">latent_dim</span><span class="o">=</span><span class="n">latent_dim</span><span class="p">)]</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 2/6332 ━━━━━━━━━━━━━━━━━━━━ 9:54 94ms/step - d_loss: 0.6792 - g_loss: 0.7880
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1704214667.959762 1319 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
6332/6332 ━━━━━━━━━━━━━━━━━━━━ 557s 84ms/step - d_loss: 0.5616 - g_loss: 1.4099
&lt;keras.src.callbacks.history.History at 0x7f251d32bc40&gt;
</code></pre></div> </div> <p>Some of the last generated images around epoch 30 (results keep improving after that):</p> <p><img alt="results" src="https://i.imgur.com/h5MtQZ7l.png" /></p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#dcgan-to-generate-face-images'>DCGAN to generate face images</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#prepare-celeba-data'>Prepare CelebA data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#create-the-discriminator'>Create the discriminator</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#create-the-generator'>Create the generator</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#override-trainstep'>Override <code>train_step</code></a> </div> <div class='k-outline-depth-2'> ◆ <a href='#create-a-callback-that-periodically-saves-generated-images'>Create a callback that periodically saves generated images</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-the-endtoend-model'>Train the end-to-end model</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a
href="https://policies.google.com/privacy">Privacy</a> </footer> </html>