GauGAN for conditional image generation

<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/generative/gaugan/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: GauGAN for conditional image generation"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: GauGAN for conditional image generation"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>GauGAN for conditional image generation</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); 
</script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink active" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink2" href="/examples/generative/ddim/">Denoising Diffusion Implicit Models</a> <a class="nav-sublink2" href="/examples/generative/random_walks_with_stable_diffusion_3/">A walk through latent space with Stable Diffusion 3</a> <a class="nav-sublink2" href="/examples/generative/dreambooth/">DreamBooth</a> <a class="nav-sublink2" href="/examples/generative/ddpm/">Denoising Diffusion Probabilistic Models</a> <a class="nav-sublink2" href="/examples/generative/fine_tune_via_textual_inversion/">Teach StableDiffusion new concepts via Textual Inversion</a> <a class="nav-sublink2" href="/examples/generative/finetune_stable_diffusion/">Fine-tuning Stable Diffusion</a> <a 
class="nav-sublink2" href="/examples/generative/vae/">Variational AutoEncoder</a> <a class="nav-sublink2" href="/examples/generative/dcgan_overriding_train_step/">GAN overriding Model.train_step</a> <a class="nav-sublink2" href="/examples/generative/wgan_gp/">WGAN-GP overriding Model.train_step</a> <a class="nav-sublink2" href="/examples/generative/conditional_gan/">Conditional GAN</a> <a class="nav-sublink2" href="/examples/generative/cyclegan/">CycleGAN</a> <a class="nav-sublink2" href="/examples/generative/gan_ada/">Data-efficient GANs with Adaptive Discriminator Augmentation</a> <a class="nav-sublink2" href="/examples/generative/deep_dream/">Deep Dream</a> <a class="nav-sublink2 active" href="/examples/generative/gaugan/">GauGAN for conditional image generation</a> <a class="nav-sublink2" href="/examples/generative/pixelcnn/">PixelCNN</a> <a class="nav-sublink2" href="/examples/generative/stylegan/">Face image generation with StyleGAN</a> <a class="nav-sublink2" href="/examples/generative/vq_vae/">Vector-Quantized Variational Autoencoders</a> <a class="nav-sublink2" href="/examples/generative/random_walks_with_stable_diffusion/">A walk through latent space with Stable Diffusion</a> <a class="nav-sublink2" href="/examples/generative/neural_style_transfer/">Neural style transfer</a> <a class="nav-sublink2" href="/examples/generative/adain/">Neural Style Transfer with AdaIN</a> <a class="nav-sublink2" href="/examples/generative/gpt2_text_generation_with_keras_hub/">GPT2 Text Generation with KerasHub</a> <a class="nav-sublink2" href="/examples/generative/text_generation_gpt/">GPT text generation from scratch with KerasHub</a> <a class="nav-sublink2" href="/examples/generative/text_generation_with_miniature_gpt/">Text generation with a miniature GPT</a> <a class="nav-sublink2" href="/examples/generative/lstm_character_level_text_generation/">Character-level text generation with LSTM</a> <a class="nav-sublink2" href="/examples/generative/text_generation_fnet/">Text 
Generation using FNet</a> <a class="nav-sublink2" href="/examples/generative/midi_generation_with_transformer/">Music Generation with Transformer Models</a> <a class="nav-sublink2" href="/examples/generative/molecule_generation/">Drug Molecule Generation with VAE</a> <a class="nav-sublink2" href="/examples/generative/wgan-graphs/">WGAN-GP with R-GCN for the generation of small molecular graphs</a> <a class="nav-sublink2" href="/examples/generative/real_nvp/">Density estimation using Real NVP</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } 
window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + encodeURIComponent(query); return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a
href='/examples/generative/'>Generative Deep Learning</a> / GauGAN for conditional image generation </div> <div class='k-content'> <h1 id="gaugan-for-conditional-image-generation">GauGAN for conditional image generation</h1> <p><strong>Author:</strong> <a href="https://github.com/soumik12345">Soumik Rakshit</a>, <a href="https://twitter.com/RisingSayak">Sayak Paul</a><br> <strong>Date created:</strong> 2021/12/26<br> <strong>Last modified:</strong> 2022/01/03<br> <strong>Description:</strong> Implementing a GauGAN for conditional image generation.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/generative/ipynb/gaugan.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/generative/gaugan.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>In this example, we present an implementation of the GauGAN architecture proposed in <a href="https://arxiv.org/abs/1903.07291">Semantic Image Synthesis with Spatially-Adaptive Normalization</a>. 
Briefly, GauGAN uses a Generative Adversarial Network (GAN) to generate realistic images conditioned on cue images and segmentation maps, as shown below (<a href="https://nvlabs.github.io/SPADE/">image source</a>):</p> <p><img alt="" src="https://i.ibb.co/p305dzv/image.png" /></p> <p>The main components of a GauGAN are:</p> <ul> <li><strong>SPADE (aka spatially-adaptive normalization)</strong>: The authors of GauGAN argue that conventional normalization layers (such as <a href="https://arxiv.org/abs/1502.03167">Batch Normalization</a>) tend to wash away the semantic information contained in the segmentation maps provided as inputs. To address this problem, the authors introduce SPADE, a normalization layer that learns spatially adaptive affine parameters (scale and bias). Instead of using a single scale and bias per channel, SPADE predicts them from the segmentation map with a small convolutional network, so the scale and bias vary across spatial locations according to the semantic labels.</li> <li><strong>Variational encoder</strong>: Inspired by <a href="https://arxiv.org/abs/1312.6114">Variational Autoencoders</a>, GauGAN uses a variational formulation wherein an encoder learns the mean and variance of a normal (Gaussian) distribution from the cue images. (Despite this Gaussian formulation, the name GauGAN is usually described as a nod to the painter Paul Gauguin.) The generator of GauGAN takes as inputs latents sampled from this Gaussian distribution as well as the one-hot encoded semantic segmentation label maps. The cue images act as style images that steer the generator toward a particular visual style. 
This variational formulation helps GauGAN achieve image diversity as well as fidelity.</li> <li><strong>Multi-scale patch discriminator</strong>: Inspired by the <a href="https://paperswithcode.com/method/patchgan">PatchGAN</a> model, GauGAN uses a discriminator that assesses a given image patch by patch and produces an averaged score.</li> </ul> <p>As we proceed with the example, we will discuss each of these components in further detail.</p> <p>For a thorough review of GauGAN, please refer to <a href="https://blog.paperspace.com/nvidia-gaugan-introduction/">this article</a>. We also encourage you to check out <a href="https://nvlabs.github.io/SPADE/">the official GauGAN website</a>, which showcases many creative applications of GauGAN. This example assumes that the reader is already familiar with the fundamental concepts of GANs. If you need a refresher, the following resources might be useful:</p> <ul> <li><a href="https://livebook.manning.com/book/deep-learning-with-python/chapter-8">Chapter on GANs</a> from the Deep Learning with Python book by François Chollet.</li> <li>GAN implementations on keras.io: <ul> <li><a href="https://keras.io/examples/generative/gan_ada">Data-efficient GANs</a></li> <li><a href="https://keras.io/examples/generative/cyclegan">CycleGAN</a></li> <li><a href="https://keras.io/examples/generative/conditional_gan">Conditional GAN</a></li> </ul> </li> </ul> <hr /> <h2 id="data-collection">Data collection</h2> <p>We will be using the <a href="https://cmp.felk.cvut.cz/~tylecr1/facade/">Facades dataset</a> for training our GauGAN model. 
Let's first download it.</p> <div class="codehilite"><pre><span></span><code><span class="err">!</span><span class="n">wget</span> <span class="n">https</span><span class="p">:</span><span class="o">//</span><span class="n">drive</span><span class="o">.</span><span class="n">google</span><span class="o">.</span><span class="n">com</span><span class="o">/</span><span class="n">uc</span><span class="err">?</span><span class="nb">id</span><span class="o">=</span><span class="mi">1</span><span class="n">q4FEjQg1YSb4mPx2VdxL7LXKYu3voTMj</span> <span class="o">-</span><span class="n">O</span> <span class="n">facades_data</span><span class="o">.</span><span class="n">zip</span> <span class="err">!</span><span class="n">unzip</span> <span class="o">-</span><span class="n">q</span> <span class="n">facades_data</span><span class="o">.</span><span class="n">zip</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>--2024-01-11 22:46:32-- https://drive.google.com/uc?id=1q4FEjQg1YSb4mPx2VdxL7LXKYu3voTMj Resolving drive.google.com (drive.google.com)... 64.233.181.138, 64.233.181.102, 64.233.181.100, ... Connecting to drive.google.com (drive.google.com)|64.233.181.138|:443... connected. HTTP request sent, awaiting response... 303 See Other Location: https://drive.usercontent.google.com/download?id=1q4FEjQg1YSb4mPx2VdxL7LXKYu3voTMj [following] --2024-01-11 22:46:32-- https://drive.usercontent.google.com/download?id=1q4FEjQg1YSb4mPx2VdxL7LXKYu3voTMj Resolving drive.usercontent.google.com (drive.usercontent.google.com)... 108.177.112.132, 2607:f8b0:4001:c12::84 Connecting to drive.usercontent.google.com (drive.usercontent.google.com)|108.177.112.132|:443... connected. HTTP request sent, awaiting response... 
200 OK Length: 26036052 (25M) [application/octet-stream] Saving to: ‘facades_data.zip’ </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>facades_data.zip 100%[===================&gt;] 24.83M 94.3MB/s in 0.3s </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>2024-01-11 22:46:42 (94.3 MB/s) - ‘facades_data.zip’ saved [26036052/26036052] </code></pre></div> </div> <hr /> <h2 id="imports">Imports</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span><span class="w"> </span><span class="nn">os</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">&quot;KERAS_BACKEND&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="s2">&quot;tensorflow&quot;</span> <span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span> <span class="kn">import</span><span class="w"> </span><span class="nn">matplotlib.pyplot</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">plt</span> <span class="kn">import</span><span class="w"> </span><span class="nn">tensorflow</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">tf</span> <span class="kn">import</span><span class="w"> </span><span class="nn">keras</span> <span class="kn">from</span><span class="w"> </span><span class="nn">keras</span><span class="w"> </span><span class="kn">import</span> <span class="n">ops</span> <span class="kn">from</span><span class="w"> </span><span class="nn">keras</span><span class="w"> </span><span class="kn">import</span> <span class="n">layers</span> <span class="kn">from</span><span class="w"> </span><span class="nn">glob</span><span class="w"> </span><span 
class="kn">import</span> <span class="n">glob</span> </code></pre></div> <hr /> <h2 id="data-splitting">Data splitting</h2> <div class="codehilite"><pre><span></span><code><span class="n">PATH</span> <span class="o">=</span> <span class="s2">&quot;./facades_data/&quot;</span> <span class="n">SPLIT</span> <span class="o">=</span> <span class="mf">0.2</span> <span class="n">files</span> <span class="o">=</span> <span class="n">glob</span><span class="p">(</span><span class="n">PATH</span> <span class="o">+</span> <span class="s2">&quot;*.jpg&quot;</span><span class="p">)</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">files</span><span class="p">)</span> <span class="n">split_index</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">files</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">SPLIT</span><span class="p">))</span> <span class="n">train_files</span> <span class="o">=</span> <span class="n">files</span><span class="p">[:</span><span class="n">split_index</span><span class="p">]</span> <span class="n">val_files</span> <span class="o">=</span> <span class="n">files</span><span class="p">[</span><span class="n">split_index</span><span class="p">:]</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Total samples: </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">files</span><span class="p">)</span><span class="si">}</span><span class="s2">.&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Total training samples: </span><span class="si">{</span><span 
class="nb">len</span><span class="p">(</span><span class="n">train_files</span><span class="p">)</span><span class="si">}</span><span class="s2">.&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Total validation samples: </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">val_files</span><span class="p">)</span><span class="si">}</span><span class="s2">.&quot;</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Total samples: 378. Total training samples: 302. Total validation samples: 76. </code></pre></div> </div> <hr /> <h2 id="data-loader">Data loader</h2> <div class="codehilite"><pre><span></span><code><span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">4</span> <span class="n">IMG_HEIGHT</span> <span class="o">=</span> <span class="n">IMG_WIDTH</span> <span class="o">=</span> <span class="mi">256</span> <span class="n">NUM_CLASSES</span> <span class="o">=</span> <span class="mi">12</span> <span class="n">AUTOTUNE</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span> <span class="k">def</span><span class="w"> </span><span class="nf">load</span><span class="p">(</span><span class="n">image_files</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">is_train</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span> <span class="k">def</span><span class="w"> </span><span class="nf">_random_crop</span><span class="p">(</span> <span class="n">segmentation_map</span><span class="p">,</span> <span class="n">image</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span> <span class="n">crop_size</span><span 
class="o">=</span><span class="p">(</span><span class="n">IMG_HEIGHT</span><span class="p">,</span> <span class="n">IMG_WIDTH</span><span class="p">),</span> <span class="p">):</span> <span class="n">crop_size</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">convert_to_tensor</span><span class="p">(</span><span class="n">crop_size</span><span class="p">)</span> <span class="n">image_shape</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">image</span><span class="p">)[:</span><span class="mi">2</span><span class="p">]</span> <span class="n">margins</span> <span class="o">=</span> <span class="n">image_shape</span> <span class="o">-</span> <span class="n">crop_size</span> <span class="n">y1</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(),</span> <span class="n">maxval</span><span class="o">=</span><span class="n">margins</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> <span class="n">x1</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(),</span> <span class="n">maxval</span><span class="o">=</span><span class="n">margins</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span 
class="n">int32</span><span class="p">)</span> <span class="n">y2</span> <span class="o">=</span> <span class="n">y1</span> <span class="o">+</span> <span class="n">crop_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="n">x2</span> <span class="o">=</span> <span class="n">x1</span> <span class="o">+</span> <span class="n">crop_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="n">cropped_images</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">images</span> <span class="o">=</span> <span class="p">[</span><span class="n">segmentation_map</span><span class="p">,</span> <span class="n">image</span><span class="p">,</span> <span class="n">labels</span><span class="p">]</span> <span class="k">for</span> <span class="n">img</span> <span class="ow">in</span> <span class="n">images</span><span class="p">:</span> <span class="n">cropped_images</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">y1</span><span class="p">:</span><span class="n">y2</span><span class="p">,</span> <span class="n">x1</span><span class="p">:</span><span class="n">x2</span><span class="p">])</span> <span class="k">return</span> <span class="n">cropped_images</span> <span class="k">def</span><span class="w"> </span><span class="nf">_load_data_tf</span><span class="p">(</span><span class="n">image_file</span><span class="p">,</span> <span class="n">segmentation_map_file</span><span class="p">,</span> <span class="n">label_file</span><span class="p">):</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">decode_png</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span 
class="n">read_file</span><span class="p">(</span><span class="n">image_file</span><span class="p">),</span> <span class="n">channels</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span> <span class="n">segmentation_map</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">decode_png</span><span class="p">(</span> <span class="n">tf</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">read_file</span><span class="p">(</span><span class="n">segmentation_map_file</span><span class="p">),</span> <span class="n">channels</span><span class="o">=</span><span class="mi">3</span> <span class="p">)</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">decode_bmp</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">read_file</span><span class="p">(</span><span class="n">label_file</span><span class="p">),</span> <span class="n">channels</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">/</span> <span class="mf">127.5</span> <span class="o">-</span> <span class="mi">1</span> <span class="n">segmentation_map</span> <span class="o">=</span> <span 
class="n">tf</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">segmentation_map</span><span class="p">,</span> <span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">/</span> <span class="mf">127.5</span> <span class="o">-</span> <span class="mi">1</span> <span class="k">return</span> <span class="n">segmentation_map</span><span class="p">,</span> <span class="n">image</span><span class="p">,</span> <span class="n">labels</span> <span class="k">def</span><span class="w"> </span><span class="nf">_one_hot</span><span class="p">(</span><span class="n">segmentation_maps</span><span class="p">,</span> <span class="n">real_images</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">NUM_CLASSES</span><span class="p">)</span> <span class="n">labels</span><span class="o">.</span><span class="n">set_shape</span><span class="p">((</span><span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="n">NUM_CLASSES</span><span class="p">))</span> <span class="k">return</span> <span class="n">segmentation_maps</span><span class="p">,</span> <span class="n">real_images</span><span class="p">,</span> <span class="n">labels</span> <span class="n">segmentation_map_files</span> <span class="o">=</span> <span class="p">[</span> <span class="n">image_file</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;images&quot;</span><span class="p">,</span> <span class="s2">&quot;segmentation_map&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span 
class="s2">&quot;jpg&quot;</span><span class="p">,</span> <span class="s2">&quot;png&quot;</span><span class="p">)</span> <span class="k">for</span> <span class="n">image_file</span> <span class="ow">in</span> <span class="n">image_files</span> <span class="p">]</span> <span class="n">label_files</span> <span class="o">=</span> <span class="p">[</span> <span class="n">image_file</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;images&quot;</span><span class="p">,</span> <span class="s2">&quot;segmentation_labels&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">&quot;jpg&quot;</span><span class="p">,</span> <span class="s2">&quot;bmp&quot;</span><span class="p">)</span> <span class="k">for</span> <span class="n">image_file</span> <span class="ow">in</span> <span class="n">image_files</span> <span class="p">]</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">(</span> <span class="p">(</span><span class="n">image_files</span><span class="p">,</span> <span class="n">segmentation_map_files</span><span class="p">,</span> <span class="n">label_files</span><span class="p">)</span> <span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="mi">10</span><span class="p">)</span> <span class="k">if</span> <span class="n">is_train</span> <span class="k">else</span> <span class="n">dataset</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span 
class="n">map</span><span class="p">(</span><span class="n">_load_data_tf</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTOTUNE</span><span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">_random_crop</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTOTUNE</span><span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">_one_hot</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTOTUNE</span><span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">drop_remainder</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="k">return</span> <span class="n">dataset</span> <span class="n">train_dataset</span> <span class="o">=</span> <span class="n">load</span><span class="p">(</span><span class="n">train_files</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">is_train</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">val_dataset</span> <span class="o">=</span> <span class="n">load</span><span class="p">(</span><span class="n">val_files</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">BATCH_SIZE</span><span class="p">,</span> <span 
class="n">is_train</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> </code></pre></div> <p>Now, let's visualize a few samples from the training set.</p> <div class="codehilite"><pre><span></span><code><span class="n">sample_train_batch</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Segmentation map batch shape: </span><span class="si">{</span><span class="n">sample_train_batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">.&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Image batch shape: </span><span class="si">{</span><span class="n">sample_train_batch</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">.&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;One-hot encoded label map shape: </span><span class="si">{</span><span class="n">sample_train_batch</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">.&quot;</span><span class="p">)</span> <span class="c1"># Plot a few samples from the training set.</span> <span class="k">for</span> <span class="n">segmentation_map</span><span class="p">,</span> <span class="n">real_image</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span
class="n">sample_train_batch</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">sample_train_batch</span><span class="p">[</span><span class="mi">1</span><span class="p">]):</span> <span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="n">fig</span><span class="o">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Segmentation Map&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">((</span><span class="n">segmentation_map</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span><span class="p">)</span> <span class="n">fig</span><span class="o">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Real Image&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">((</span><span class="n">real_image</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span><span class="p">)</span> <span class="n">plt</span><span 
class="o">.</span><span class="n">show</span><span class="p">()</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Segmentation map batch shape: (4, 256, 256, 3). Image batch shape: (4, 256, 256, 3). One-hot encoded label map shape: (4, 256, 256, 12). </code></pre></div> </div> <p><img alt="png" src="/img/examples/generative/gaugan/gaugan_11_1.png" /></p> <p><img alt="png" src="/img/examples/generative/gaugan/gaugan_11_2.png" /></p> <p><img alt="png" src="/img/examples/generative/gaugan/gaugan_11_3.png" /></p> <p>Note that in the rest of this example, we use a couple of figures from the <a href="https://arxiv.org/abs/1903.07291">original GauGAN paper</a> for convenience.</p> <hr /> <h2 id="custom-layers">Custom layers</h2> <p>In this section, we implement the following layers:</p> <ul> <li>SPADE</li> <li>Residual block including SPADE</li> <li>Gaussian sampler</li> </ul> <h3 id="some-more-notes-on-spade">Some more notes on SPADE</h3> <p><img alt="" src="https://i.imgur.com/DgMWrrs.png" /></p> <p><strong>SPatially-Adaptive (DE) normalization</strong> or <strong>SPADE</strong> is a simple but effective layer for synthesizing photorealistic images given an input semantic layout. Previous methods for conditional image generation from semantic input, such as Pix2Pix (<a href="https://arxiv.org/abs/1611.07004">Isola et al.</a>) or Pix2PixHD (<a href="https://arxiv.org/abs/1711.11585">Wang et al.</a>), feed the semantic layout directly into the deep network, where it is processed through stacks of convolution, normalization, and nonlinearity layers. This is often suboptimal because the normalization layers tend to wash away semantic information.</p> <p>In SPADE, the segmentation mask is first projected onto an embedding space, and then convolved to produce the modulation parameters <code>γ</code> and <code>β</code>. 
Unlike prior conditional normalization methods, <code>γ</code> and <code>β</code> are not vectors, but tensors with spatial dimensions. The produced <code>γ</code> and <code>β</code> are multiplied and added to the normalized activation element-wise. As the modulation parameters are adaptive to the input segmentation mask, SPADE is better suited for semantic image synthesis.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">SPADE</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">filters</span><span class="p">,</span> <span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">epsilon</span> <span class="o">=</span> <span class="n">epsilon</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">&quot;same&quot;</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">)</span> <span class="bp">self</span><span 
class="o">.</span><span class="n">conv_gamma</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">filters</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">&quot;same&quot;</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_beta</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">filters</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">&quot;same&quot;</span><span class="p">)</span> <span class="k">def</span><span class="w"> </span><span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">resize_shape</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">:</span><span class="mi">3</span><span class="p">]</span> <span class="k">def</span><span class="w"> </span><span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_tensor</span><span class="p">,</span> <span class="n">raw_mask</span><span class="p">):</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">raw_mask</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">resize_shape</span><span class="p">,</span> <span 
class="n">interpolation</span><span class="o">=</span><span class="s2">&quot;nearest&quot;</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="p">(</span><span class="n">mask</span><span class="p">)</span> <span class="n">gamma</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_gamma</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">beta</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_beta</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">mean</span><span class="p">,</span> <span class="n">var</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">moments</span><span class="p">(</span><span class="n">input_tensor</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">std</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">var</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">epsilon</span><span class="p">)</span> <span class="n">normalized</span> <span class="o">=</span> <span class="p">(</span><span class="n">input_tensor</span> <span class="o">-</span> <span class="n">mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">std</span> <span class="n">output</span> <span class="o">=</span> <span class="n">gamma</span> 
<span class="o">*</span> <span class="n">normalized</span> <span class="o">+</span> <span class="n">beta</span> <span class="k">return</span> <span class="n">output</span> <span class="k">class</span><span class="w"> </span><span class="nc">ResBlock</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">filters</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">filters</span> <span class="o">=</span> <span class="n">filters</span> <span class="k">def</span><span class="w"> </span><span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="n">input_filter</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="bp">self</span><span class="o">.</span><span class="n">spade_1</span> <span class="o">=</span> <span class="n">SPADE</span><span class="p">(</span><span class="n">input_filter</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">spade_2</span> <span class="o">=</span> <span class="n">SPADE</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span 
class="n">conv_1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">&quot;same&quot;</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">&quot;same&quot;</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">learned_skip</span> <span class="o">=</span> <span class="kc">False</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">filters</span> <span class="o">!=</span> <span class="n">input_filter</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">learned_skip</span> <span class="o">=</span> <span class="kc">True</span> <span class="bp">self</span><span class="o">.</span><span class="n">spade_3</span> <span class="o">=</span> <span class="n">SPADE</span><span class="p">(</span><span class="n">input_filter</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_3</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">filters</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span 
class="n">padding</span><span class="o">=</span><span class="s2">&quot;same&quot;</span><span class="p">)</span> <span class="k">def</span><span class="w"> </span><span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">spade_1</span><span class="p">(</span><span class="n">input_tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_1</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">activations</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">))</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">spade_2</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_2</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">activations</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">))</span> <span class="n">skip</span> <span class="o">=</span> <span class="p">(</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_3</span><span class="p">(</span> <span class="n">keras</span><span class="o">.</span><span 
class="n">activations</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">spade_3</span><span class="p">(</span><span class="n">input_tensor</span><span class="p">,</span> <span class="n">mask</span><span class="p">),</span> <span class="mf">0.2</span><span class="p">)</span> <span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">learned_skip</span> <span class="k">else</span> <span class="n">input_tensor</span> <span class="p">)</span> <span class="n">output</span> <span class="o">=</span> <span class="n">skip</span> <span class="o">+</span> <span class="n">x</span> <span class="k">return</span> <span class="n">output</span> <span class="k">class</span><span class="w"> </span><span class="nc">GaussianSampler</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">latent_dim</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="bp">self</span><span class="o">.</span><span class="n">latent_dim</span> <span class="o">=</span> <span class="n">latent_dim</span> <span class="bp">self</span><span class="o">.</span><span class="n">seed_generator</span> <span class="o">=</span> 
<span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">SeedGenerator</span><span class="p">(</span><span class="mi">1337</span><span class="p">)</span> <span class="k">def</span><span class="w"> </span><span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">means</span><span class="p">,</span> <span class="n">variance</span> <span class="o">=</span> <span class="n">inputs</span> <span class="n">epsilon</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">latent_dim</span><span class="p">),</span> <span class="n">mean</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">stddev</span><span class="o">=</span><span class="mf">1.0</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">seed_generator</span><span class="p">,</span> <span class="p">)</span> <span class="n">samples</span> <span class="o">=</span> <span class="n">means</span> <span class="o">+</span> <span class="n">ops</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">variance</span><span class="p">)</span> <span class="o">*</span> <span class="n">epsilon</span> <span class="k">return</span> <span class="n">samples</span> </code></pre></div> <p>Next, we implement the downsampling block for the encoder.</p> 
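<p>Each downsampling block uses a stride-2 convolution with <code>padding="same"</code>. That convolution maps a spatial size <code>n</code> to <code>ceil(n / 2)</code>. The helper below is a hypothetical sanity check and is not part of the original example. It chains that rule through the five <code>downsample</code> calls in <code>build_encoder</code>, which is defined below. It assumes the 256x256 crops shown in the printed batch shapes and the default <code>encoder_downsample_factor=64</code>. It only does the shape arithmetic and builds no layers.</p>

```python
import math


def conv_output_size(size, stride=2):
    # With padding="same", a strided convolution yields ceil(size / stride) outputs.
    return math.ceil(size / stride)


def encoder_shapes(image_size=256, encoder_downsample_factor=64):
    # Channel multipliers of the five downsample() calls in build_encoder
    # (assumed to match the encoder code shown below).
    multipliers = [1, 2, 4, 8, 8]
    shapes = []
    size = image_size
    for m in multipliers:
        size = conv_output_size(size)
        shapes.append((size, size, m * encoder_downsample_factor))
    return shapes


shapes = encoder_shapes()
for shape in shapes:
    print(shape)

# Size of the vector produced by layers.Flatten() before the Dense heads.
h, w, c = shapes[-1]
print("Flattened features:", h * w * c)
```

<p>Under these assumptions, the final feature map is 8x8x512. The <code>Flatten</code> layer therefore feeds 32,768 features into the <code>mean</code> and <code>variance</code> dense layers.</p>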
<div class="codehilite"><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">downsample</span><span class="p">(</span> <span class="n">channels</span><span class="p">,</span> <span class="n">kernels</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">apply_norm</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">apply_activation</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">apply_dropout</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="p">):</span> <span class="n">block</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">()</span> <span class="n">block</span><span class="o">.</span><span class="n">add</span><span class="p">(</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span> <span class="n">channels</span><span class="p">,</span> <span class="n">kernels</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="n">strides</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">&quot;same&quot;</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">kernel_initializer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">GlorotNormal</span><span class="p">(),</span> <span class="p">)</span> <span class="p">)</span> <span class="k">if</span> <span class="n">apply_norm</span><span class="p">:</span> <span class="n">block</span><span class="o">.</span><span 
class="n">add</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">GroupNormalization</span><span class="p">(</span><span class="n">groups</span><span class="o">=-</span><span class="mi">1</span><span class="p">))</span> <span class="k">if</span> <span class="n">apply_activation</span><span class="p">:</span> <span class="n">block</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="mf">0.2</span><span class="p">))</span> <span class="k">if</span> <span class="n">apply_dropout</span><span class="p">:</span> <span class="n">block</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.5</span><span class="p">))</span> <span class="k">return</span> <span class="n">block</span> </code></pre></div> <p>The GauGAN encoder consists of five downsampling blocks. 
It outputs the mean and the log-variance of a Gaussian distribution. The output named <code>variance</code> is treated as a log-variance by <code>GaussianSampler</code>, which scales the noise by <code>exp(0.5 * variance)</code>.</p> <p><img alt="" src="https://i.imgur.com/JgAv1EW.png" /></p> <div class="codehilite"><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">build_encoder</span><span class="p">(</span><span class="n">image_shape</span><span class="p">,</span> <span class="n">encoder_downsample_factor</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">latent_dim</span><span class="o">=</span><span class="mi">256</span><span class="p">):</span> <span class="n">input_image</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">image_shape</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">downsample</span><span class="p">(</span><span class="n">encoder_downsample_factor</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">apply_norm</span><span class="o">=</span><span class="kc">False</span><span class="p">)(</span><span class="n">input_image</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">downsample</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">encoder_downsample_factor</span><span class="p">,</span> <span class="mi">3</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">downsample</span><span class="p">(</span><span class="mi">4</span> <span class="o">*</span> <span class="n">encoder_downsample_factor</span><span class="p">,</span> <span class="mi">3</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">downsample</span><span 
class="p">(</span><span class="mi">8</span> <span class="o">*</span> <span class="n">encoder_downsample_factor</span><span class="p">,</span> <span class="mi">3</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">downsample</span><span class="p">(</span><span class="mi">8</span> <span class="o">*</span> <span class="n">encoder_downsample_factor</span><span class="p">,</span> <span class="mi">3</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Flatten</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">mean</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">latent_dim</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;mean&quot;</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">variance</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">latent_dim</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;variance&quot;</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">input_image</span><span class="p">,</span> <span class="p">[</span><span class="n">mean</span><span class="p">,</span> <span class="n">variance</span><span class="p">],</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;encoder&quot;</span><span class="p">)</span> 
</code></pre></div> <p>Next, we implement the generator, which consists of the modified residual blocks and upsampling blocks. It takes latent vectors and one-hot encoded segmentation labels, and produces new images.</p> <p><img alt="" src="https://i.imgur.com/9iP1TsB.png" /></p> <p>With SPADE, there is no need to feed the segmentation map to the first layer of the generator, since the latent inputs have enough structural information about the style we want the generator to emulate. We also discard the encoder part of the generator, which is commonly used in prior architectures. This results in a more lightweight generator network, which can also take a random vector as input, enabling a simple and natural path to multi-modal synthesis.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">build_generator</span><span class="p">(</span><span class="n">mask_shape</span><span class="p">,</span> <span class="n">latent_dim</span><span class="o">=</span><span class="mi">256</span><span class="p">):</span> <span class="n">latent</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">latent_dim</span><span class="p">,))</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">mask_shape</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">16384</span><span class="p">)(</span><span class="n">latent</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span 
class="o">.</span><span class="n">Reshape</span><span class="p">((</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">1024</span><span class="p">))(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ResBlock</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">1024</span><span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">UpSampling2D</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">))(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ResBlock</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">1024</span><span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">UpSampling2D</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">))(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ResBlock</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">1024</span><span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">UpSampling2D</span><span 
class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">))(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ResBlock</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">512</span><span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">UpSampling2D</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">))(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ResBlock</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">256</span><span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">UpSampling2D</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">))(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ResBlock</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">128</span><span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">UpSampling2D</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span 
class="p">))(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">activations</span><span class="o">.</span><span class="n">leaky_relu</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">)</span> <span class="n">output_image</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">activations</span><span class="o">.</span><span class="n">tanh</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">&quot;same&quot;</span><span class="p">)(</span><span class="n">x</span><span class="p">))</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">([</span><span class="n">latent</span><span class="p">,</span> <span class="n">mask</span><span class="p">],</span> <span class="n">output_image</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;generator&quot;</span><span class="p">)</span> </code></pre></div> <p>The discriminator takes a segmentation map and an image and concatenates them. 
It then predicts whether each patch of the concatenated input is real or fake. Along with these patch predictions, it returns the feature maps of its intermediate downsampling blocks, which the feature matching loss uses.</p> <p><img alt="" src="https://i.imgur.com/rn71PlM.png" /></p> <div class="codehilite"><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">build_discriminator</span><span class="p">(</span><span class="n">image_shape</span><span class="p">,</span> <span class="n">downsample_factor</span><span class="o">=</span><span class="mi">64</span><span class="p">):</span> <span class="n">input_image_A</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">image_shape</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;discriminator_image_A&quot;</span><span class="p">)</span> <span class="n">input_image_B</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">image_shape</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;discriminator_image_B&quot;</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Concatenate</span><span class="p">()([</span><span class="n">input_image_A</span><span class="p">,</span> <span class="n">input_image_B</span><span class="p">])</span> <span class="n">x1</span> <span class="o">=</span> <span class="n">downsample</span><span class="p">(</span><span class="n">downsample_factor</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="n">apply_norm</span><span class="o">=</span><span class="kc">False</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x2</span> <span
class="o">=</span> <span class="n">downsample</span><span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">downsample_factor</span><span class="p">,</span> <span class="mi">4</span><span class="p">)(</span><span class="n">x1</span><span class="p">)</span> <span class="n">x3</span> <span class="o">=</span> <span class="n">downsample</span><span class="p">(</span><span class="mi">4</span> <span class="o">*</span> <span class="n">downsample_factor</span><span class="p">,</span> <span class="mi">4</span><span class="p">)(</span><span class="n">x2</span><span class="p">)</span> <span class="n">x4</span> <span class="o">=</span> <span class="n">downsample</span><span class="p">(</span><span class="mi">8</span> <span class="o">*</span> <span class="n">downsample_factor</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">1</span><span class="p">)(</span><span class="n">x3</span><span class="p">)</span> <span class="n">x5</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">)(</span><span class="n">x4</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">x1</span><span class="p">,</span> <span class="n">x2</span><span class="p">,</span> <span class="n">x3</span><span class="p">,</span> <span class="n">x4</span><span class="p">,</span> <span class="n">x5</span><span class="p">]</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">([</span><span class="n">input_image_A</span><span class="p">,</span> <span class="n">input_image_B</span><span class="p">],</span> <span class="n">outputs</span><span class="p">)</span> 
</code></pre></div> <hr /> <h2 id="loss-functions">Loss functions</h2> <p>GauGAN uses the following loss functions:</p> <ul> <li>Generator: <ul> <li>Expectation over the discriminator predictions (<code>generator_loss</code>), computed as the negated mean of the predictions.</li> <li><a href="https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence">KL divergence</a> (<code>kl_divergence_loss</code>) for learning the mean and variance predicted by the encoder. It pulls the encoder's Gaussian toward a standard normal prior. Note that the encoder's <code>variance</code> output is treated as a log-variance, because the loss exponentiates it.</li> <li>Feature matching loss (<code>FeatureMatchingLoss</code>): the sum of mean absolute errors between the discriminator's intermediate feature maps for original and generated images. This loss aligns the feature space of the generator with that of real images.</li> <li><a href="https://arxiv.org/abs/1603.08155">Perceptual loss</a> (<code>VGGFeatureMatchingLoss</code>): a weighted sum of mean absolute errors between VGG19 features of original and generated images. This loss encourages the generated images to be perceptually similar to the originals.</li> </ul> </li> <li>Discriminator: <ul> <li><a href="https://en.wikipedia.org/wiki/Hinge_loss">Hinge loss</a> (<code>DiscriminatorLoss</code>).</li> </ul> </li> </ul> <div class="codehilite"><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">generator_loss</span><span class="p">(</span><span class="n">y</span><span class="p">):</span> <span class="k">return</span> <span class="o">-</span><span class="n">ops</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">y</span><span class="p">)</span> <span class="k">def</span><span class="w"> </span><span class="nf">kl_divergence_loss</span><span class="p">(</span><span class="n">mean</span><span class="p">,</span> <span class="n">variance</span><span class="p">):</span> <span class="k">return</span> <span class="o">-</span><span class="mf">0.5</span> <span class="o">*</span> <span class="n">ops</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">variance</span> <span class="o">-</span> <span class="n">ops</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">mean</span><span class="p">)</span> <span class="o">-</span> <span class="n">ops</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span
class="n">variance</span><span class="p">))</span> <span class="k">class</span><span class="w"> </span><span class="nc">FeatureMatchingLoss</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">Loss</span><span class="p">):</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">mae</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">MeanAbsoluteError</span><span class="p">()</span> <span class="k">def</span><span class="w"> </span><span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span> <span class="n">loss</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">y_true</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span> <span class="n">loss</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mae</span><span class="p">(</span><span class="n">y_true</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span 
class="n">y_pred</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="k">return</span> <span class="n">loss</span> <span class="k">class</span><span class="w"> </span><span class="nc">VGGFeatureMatchingLoss</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">Loss</span><span class="p">):</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder_layers</span> <span class="o">=</span> <span class="p">[</span> <span class="s2">&quot;block1_conv1&quot;</span><span class="p">,</span> <span class="s2">&quot;block2_conv1&quot;</span><span class="p">,</span> <span class="s2">&quot;block3_conv1&quot;</span><span class="p">,</span> <span class="s2">&quot;block4_conv1&quot;</span><span class="p">,</span> <span class="s2">&quot;block5_conv1&quot;</span><span class="p">,</span> <span class="p">]</span> <span class="bp">self</span><span class="o">.</span><span class="n">weights</span> <span class="o">=</span> <span class="p">[</span><span class="mf">1.0</span> <span class="o">/</span> <span class="mi">32</span><span class="p">,</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="mi">16</span><span class="p">,</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="mi">8</span><span class="p">,</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="mi">4</span><span class="p">,</span> <span class="mf">1.0</span><span 
class="p">]</span> <span class="n">vgg</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">applications</span><span class="o">.</span><span class="n">VGG19</span><span class="p">(</span><span class="n">include_top</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">weights</span><span class="o">=</span><span class="s2">&quot;imagenet&quot;</span><span class="p">)</span> <span class="n">layer_outputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">vgg</span><span class="o">.</span><span class="n">get_layer</span><span class="p">(</span><span class="n">x</span><span class="p">)</span><span class="o">.</span><span class="n">output</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder_layers</span><span class="p">]</span> <span class="bp">self</span><span class="o">.</span><span class="n">vgg_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">vgg</span><span class="o">.</span><span class="n">input</span><span class="p">,</span> <span class="n">layer_outputs</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;VGG&quot;</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">mae</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">MeanAbsoluteError</span><span class="p">()</span> <span class="k">def</span><span class="w"> </span><span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span> <span 
class="n">y_true</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">applications</span><span class="o">.</span><span class="n">vgg19</span><span class="o">.</span><span class="n">preprocess_input</span><span class="p">(</span><span class="mf">127.5</span> <span class="o">*</span> <span class="p">(</span><span class="n">y_true</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span> <span class="n">y_pred</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">applications</span><span class="o">.</span><span class="n">vgg19</span><span class="o">.</span><span class="n">preprocess_input</span><span class="p">(</span><span class="mf">127.5</span> <span class="o">*</span> <span class="p">(</span><span class="n">y_pred</span> <span class="o">+</span> <span class="mi">1</span><span class="p">))</span> <span class="n">real_features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">vgg_model</span><span class="p">(</span><span class="n">y_true</span><span class="p">)</span> <span class="n">fake_features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">vgg_model</span><span class="p">(</span><span class="n">y_pred</span><span class="p">)</span> <span class="n">loss</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">real_features</span><span class="p">)):</span> <span class="n">loss</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">weights</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span 
class="n">mae</span><span class="p">(</span><span class="n">real_features</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">fake_features</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="k">return</span> <span class="n">loss</span> <span class="k">class</span><span class="w"> </span><span class="nc">DiscriminatorLoss</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">Loss</span><span class="p">):</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">hinge_loss</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">Hinge</span><span class="p">()</span> <span class="k">def</span><span class="w"> </span><span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">is_real</span><span class="p">):</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">hinge_loss</span><span class="p">(</span><span class="n">is_real</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> </code></pre></div> <p><code>DiscriminatorLoss</code> wraps the <a href="https://en.wikipedia.org/wiki/Hinge_loss">hinge loss</a>. Because <code>keras.losses.Hinge</code> expects targets of -1 or 1, <code>is_real</code> should be 1 for the discriminator's predictions on real images and -1 for its predictions on generated images.</p>
<hr /> <h2 id="gan-monitor-callback">GAN monitor callback</h2> <p>Next, we implement a callback that monitors GauGAN while it trains. Every <code>epoch_interval</code> epochs, it plots the segmentation map, the ground-truth image, and the generated image for a batch of validation samples.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">GanMonitor</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">Callback</span><span class="p">):</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">val_dataset</span><span class="p">,</span> <span class="n">n_samples</span><span class="p">,</span> <span class="n">epoch_interval</span><span class="o">=</span><span class="mi">5</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">val_images</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">val_dataset</span><span class="p">))</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_samples</span> <span class="o">=</span> <span class="n">n_samples</span> <span class="bp">self</span><span class="o">.</span><span class="n">epoch_interval</span> <span class="o">=</span> <span class="n">epoch_interval</span> <span class="bp">self</span><span class="o">.</span><span class="n">seed_generator</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">SeedGenerator</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span> <span class="k">def</span><span class="w"> </span><span class="nf">infer</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="n">latent_vector</span> <span
class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">latent_dim</span><span class="p">),</span> <span class="n">mean</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">stddev</span><span class="o">=</span><span class="mf">2.0</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">seed_generator</span><span class="p">,</span> <span class="p">)</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">([</span><span class="n">latent_vector</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">val_images</span><span class="p">[</span><span class="mi">2</span><span class="p">]])</span> <span class="k">def</span><span class="w"> </span><span class="nf">on_epoch_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="k">if</span> <span class="n">epoch</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">epoch_interval</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> <span class="n">generated_images</span> 
<span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">infer</span><span class="p">()</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">n_samples</span><span class="p">):</span> <span class="n">grid_row</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">generated_images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">3</span><span class="p">)</span> <span class="n">f</span><span class="p">,</span> <span class="n">axarr</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">grid_row</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">18</span><span class="p">,</span> <span class="n">grid_row</span> <span class="o">*</span> <span class="mi">6</span><span class="p">))</span> <span class="k">for</span> <span class="n">row</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">grid_row</span><span class="p">):</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">axarr</span> <span class="k">if</span> <span class="n">grid_row</span> <span class="o">==</span> <span class="mi">1</span> <span class="k">else</span> <span class="n">axarr</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span 
class="n">val_images</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">row</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">&quot;off&quot;</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Mask&quot;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">20</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">val_images</span><span class="p">[</span><span class="mi">1</span><span class="p">][</span><span class="n">row</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">&quot;off&quot;</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Ground Truth&quot;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">20</span><span 
class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">((</span><span class="n">generated_images</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">&quot;off&quot;</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Generated&quot;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">20</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> </code></pre></div> <hr /> <h2 id="subclassed-gaugan-model">Subclassed GauGAN model</h2> <p>Finally, we put everything together inside a subclass of <a href="https://keras.io/api/models/model/"><code>keras.Model</code></a> that overrides its <code>train_step()</code> method.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">GauGAN</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">num_classes</span><span
class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">latent_dim</span><span class="p">,</span> <span class="n">feature_loss_coeff</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">vgg_feature_loss_coeff</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">kl_divergence_loss_coeff</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span> <span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_size</span> <span class="o">=</span> <span class="n">image_size</span> <span class="bp">self</span><span class="o">.</span><span class="n">latent_dim</span> <span class="o">=</span> <span class="n">latent_dim</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_classes</span> <span class="o">=</span> <span class="n">num_classes</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">num_classes</span><span 
class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">feature_loss_coeff</span> <span class="o">=</span> <span class="n">feature_loss_coeff</span> <span class="bp">self</span><span class="o">.</span><span class="n">vgg_feature_loss_coeff</span> <span class="o">=</span> <span class="n">vgg_feature_loss_coeff</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_divergence_loss_coeff</span> <span class="o">=</span> <span class="n">kl_divergence_loss_coeff</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span> <span class="o">=</span> <span class="n">build_discriminator</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">image_shape</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span> <span class="o">=</span> <span class="n">build_generator</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">mask_shape</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">build_encoder</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">image_shape</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span> <span class="o">=</span> <span class="n">GaussianSampler</span><span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">latent_dim</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">combined_model</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">build_combined_generator</span><span 
class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">disc_loss_tracker</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;disc_loss&quot;</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">gen_loss_tracker</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;gen_loss&quot;</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">feat_loss_tracker</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;feat_loss&quot;</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">vgg_loss_tracker</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;vgg_loss&quot;</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_loss_tracker</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;kl_loss&quot;</span><span class="p">)</span> <span 
class="nd">@property</span> <span class="k">def</span><span class="w"> </span><span class="nf">metrics</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="k">return</span> <span class="p">[</span> <span class="bp">self</span><span class="o">.</span><span class="n">disc_loss_tracker</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">gen_loss_tracker</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">feat_loss_tracker</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">vgg_loss_tracker</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_loss_tracker</span><span class="p">,</span> <span class="p">]</span> <span class="k">def</span><span class="w"> </span><span class="nf">build_combined_generator</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="c1"># This method builds a model that takes as inputs the following:</span> <span class="c1"># latent vector, one-hot encoded segmentation label map, and</span> <span class="c1"># a segmentation map. 
It then (i) generates an image with the generator,</span> <span class="c1"># (ii) passes the generated images and segmentation map to the discriminator.</span> <span class="c1"># Finally, the model produces the following outputs: (a) discriminator outputs,</span> <span class="c1"># (b) generated image.</span> <span class="c1"># We will be using this model to simplify the implementation.</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">False</span> <span class="n">mask_input</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">mask_shape</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;mask&quot;</span><span class="p">)</span> <span class="n">image_input</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">image_shape</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;image&quot;</span><span class="p">)</span> <span class="n">latent_input</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">latent_dim</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;latent&quot;</span><span class="p">)</span> <span class="n">generated_image</span> <span 
class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">([</span><span class="n">latent_input</span><span class="p">,</span> <span class="n">mask_input</span><span class="p">])</span> <span class="n">discriminator_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">([</span><span class="n">image_input</span><span class="p">,</span> <span class="n">generated_image</span><span class="p">])</span> <span class="n">combined_outputs</span> <span class="o">=</span> <span class="n">discriminator_output</span> <span class="o">+</span> <span class="p">[</span><span class="n">generated_image</span><span class="p">]</span> <span class="n">patch_size</span> <span class="o">=</span> <span class="n">discriminator_output</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="n">combined_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span> <span class="p">[</span><span class="n">latent_input</span><span class="p">,</span> <span class="n">mask_input</span><span class="p">,</span> <span class="n">image_input</span><span class="p">],</span> <span class="n">combined_outputs</span> <span class="p">)</span> <span class="k">return</span> <span class="n">patch_size</span><span class="p">,</span> <span class="n">combined_model</span> <span class="k">def</span><span class="w"> </span><span class="nf">compile</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">gen_lr</span><span class="o">=</span><span class="mf">1e-4</span><span class="p">,</span> <span class="n">disc_lr</span><span class="o">=</span><span 
class="mf">4e-4</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_optimizer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span> <span class="n">gen_lr</span><span class="p">,</span> <span class="n">beta_1</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">beta_2</span><span class="o">=</span><span class="mf">0.999</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_optimizer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span> <span class="n">disc_lr</span><span class="p">,</span> <span class="n">beta_1</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">beta_2</span><span class="o">=</span><span class="mf">0.999</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_loss</span> <span class="o">=</span> <span class="n">DiscriminatorLoss</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">feature_matching_loss</span> <span class="o">=</span> <span class="n">FeatureMatchingLoss</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">vgg_loss</span> <span class="o">=</span> <span class="n">VGGFeatureMatchingLoss</span><span class="p">()</span> <span 
class="k">def</span><span class="w"> </span><span class="nf">train_discriminator</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">latent_vector</span><span class="p">,</span> <span class="n">segmentation_map</span><span class="p">,</span> <span class="n">real_image</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span> <span class="n">fake_images</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">([</span><span class="n">latent_vector</span><span class="p">,</span> <span class="n">labels</span><span class="p">])</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">gradient_tape</span><span class="p">:</span> <span class="n">pred_fake</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">([</span><span class="n">segmentation_map</span><span class="p">,</span> <span class="n">fake_images</span><span class="p">])[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">pred_real</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">([</span><span class="n">segmentation_map</span><span class="p">,</span> <span class="n">real_image</span><span class="p">])[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">loss_fake</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_loss</span><span class="p">(</span><span class="n">pred_fake</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.0</span><span class="p">)</span> <span class="n">loss_real</span> <span 
class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_loss</span><span class="p">(</span><span class="n">pred_real</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">)</span> <span class="n">total_loss</span> <span class="o">=</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="p">(</span><span class="n">loss_fake</span> <span class="o">+</span> <span class="n">loss_real</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">True</span> <span class="n">gradients</span> <span class="o">=</span> <span class="n">gradient_tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span> <span class="n">total_loss</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="o">.</span><span class="n">trainable_variables</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_optimizer</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span> <span class="nb">zip</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="o">.</span><span class="n">trainable_variables</span><span class="p">)</span> <span class="p">)</span> <span class="k">return</span> <span class="n">total_loss</span> <span class="k">def</span><span class="w"> </span><span class="nf">train_generator</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">latent_vector</span><span class="p">,</span> <span class="n">segmentation_map</span><span class="p">,</span> <span class="n">labels</span><span 
class="p">,</span> <span class="n">image</span><span class="p">,</span> <span class="n">mean</span><span class="p">,</span> <span class="n">variance</span> <span class="p">):</span> <span class="c1"># Generator learns through the signal provided by the discriminator. During</span> <span class="c1"># backpropagation, we only update the generator parameters.</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">False</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="n">real_d_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">([</span><span class="n">segmentation_map</span><span class="p">,</span> <span class="n">image</span><span class="p">])</span> <span class="n">combined_outputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">combined_model</span><span class="p">(</span> <span class="p">[</span><span class="n">latent_vector</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span> <span class="n">segmentation_map</span><span class="p">]</span> <span class="p">)</span> <span class="n">fake_d_output</span><span class="p">,</span> <span class="n">fake_image</span> <span class="o">=</span> <span class="n">combined_outputs</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">combined_outputs</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">pred</span> <span class="o">=</span> <span class="n">fake_d_output</span><span class="p">[</span><span 
class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="c1"># Compute generator losses.</span> <span class="n">g_loss</span> <span class="o">=</span> <span class="n">generator_loss</span><span class="p">(</span><span class="n">pred</span><span class="p">)</span> <span class="n">kl_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_divergence_loss_coeff</span> <span class="o">*</span> <span class="n">kl_divergence_loss</span><span class="p">(</span><span class="n">mean</span><span class="p">,</span> <span class="n">variance</span><span class="p">)</span> <span class="n">vgg_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">vgg_feature_loss_coeff</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">vgg_loss</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">fake_image</span><span class="p">)</span> <span class="n">feature_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feature_loss_coeff</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">feature_matching_loss</span><span class="p">(</span> <span class="n">real_d_output</span><span class="p">,</span> <span class="n">fake_d_output</span> <span class="p">)</span> <span class="n">total_loss</span> <span class="o">=</span> <span class="n">g_loss</span> <span class="o">+</span> <span class="n">kl_loss</span> <span class="o">+</span> <span class="n">vgg_loss</span> <span class="o">+</span> <span class="n">feature_loss</span> <span class="n">all_trainable_variables</span> <span class="o">=</span> <span class="p">(</span> <span class="bp">self</span><span class="o">.</span><span class="n">combined_model</span><span class="o">.</span><span class="n">trainable_variables</span> <span 
class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="o">.</span><span class="n">trainable_variables</span> <span class="p">)</span> <span class="n">gradients</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">total_loss</span><span class="p">,</span> <span class="n">all_trainable_variables</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator_optimizer</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span> <span class="nb">zip</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="n">all_trainable_variables</span><span class="p">)</span> <span class="p">)</span> <span class="k">return</span> <span class="n">total_loss</span><span class="p">,</span> <span class="n">feature_loss</span><span class="p">,</span> <span class="n">vgg_loss</span><span class="p">,</span> <span class="n">kl_loss</span> <span class="k">def</span><span class="w"> </span><span class="nf">train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span> <span class="n">segmentation_map</span><span class="p">,</span> <span class="n">image</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">data</span> <span class="n">mean</span><span class="p">,</span> <span class="n">variance</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="n">latent_vector</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span><span class="p">([</span><span class="n">mean</span><span 
class="p">,</span> <span class="n">variance</span><span class="p">])</span> <span class="n">discriminator_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_discriminator</span><span class="p">(</span> <span class="n">latent_vector</span><span class="p">,</span> <span class="n">segmentation_map</span><span class="p">,</span> <span class="n">image</span><span class="p">,</span> <span class="n">labels</span> <span class="p">)</span> <span class="p">(</span><span class="n">generator_loss</span><span class="p">,</span> <span class="n">feature_loss</span><span class="p">,</span> <span class="n">vgg_loss</span><span class="p">,</span> <span class="n">kl_loss</span><span class="p">)</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">train_generator</span><span class="p">(</span> <span class="n">latent_vector</span><span class="p">,</span> <span class="n">segmentation_map</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span> <span class="n">image</span><span class="p">,</span> <span class="n">mean</span><span class="p">,</span> <span class="n">variance</span> <span class="p">)</span> <span class="c1"># Report progress.</span> <span class="bp">self</span><span class="o">.</span><span class="n">disc_loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">discriminator_loss</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">gen_loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">generator_loss</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">feat_loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">feature_loss</span><span class="p">)</span> <span 
class="bp">self</span><span class="o">.</span><span class="n">vgg_loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">vgg_loss</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">kl_loss</span><span class="p">)</span> <span class="n">results</span> <span class="o">=</span> <span class="p">{</span><span class="n">m</span><span class="o">.</span><span class="n">name</span><span class="p">:</span> <span class="n">m</span><span class="o">.</span><span class="n">result</span><span class="p">()</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics</span><span class="p">}</span> <span class="k">return</span> <span class="n">results</span> <span class="k">def</span><span class="w"> </span><span class="nf">test_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span> <span class="n">segmentation_map</span><span class="p">,</span> <span class="n">image</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">data</span> <span class="c1"># Obtain the learned moments of the real image distribution.</span> <span class="n">mean</span><span class="p">,</span> <span class="n">variance</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="c1"># Sample a latent from the distribution defined by the learned moments.</span> <span class="n">latent_vector</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">sampler</span><span 
class="p">([</span><span class="n">mean</span><span class="p">,</span> <span class="n">variance</span><span class="p">])</span> <span class="c1"># Generate the fake images.</span> <span class="n">fake_images</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">([</span><span class="n">latent_vector</span><span class="p">,</span> <span class="n">labels</span><span class="p">])</span> <span class="c1"># Calculate the losses.</span> <span class="n">pred_fake</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">([</span><span class="n">segmentation_map</span><span class="p">,</span> <span class="n">fake_images</span><span class="p">])[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">pred_real</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">([</span><span class="n">segmentation_map</span><span class="p">,</span> <span class="n">image</span><span class="p">])[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">loss_fake</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_loss</span><span class="p">(</span><span class="n">pred_fake</span><span class="p">,</span> <span class="o">-</span><span class="mf">1.0</span><span class="p">)</span> <span class="n">loss_real</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator_loss</span><span class="p">(</span><span class="n">pred_real</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">)</span> <span class="n">total_discriminator_loss</span> <span class="o">=</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="p">(</span><span 
class="n">loss_fake</span> <span class="o">+</span> <span class="n">loss_real</span><span class="p">)</span> <span class="n">real_d_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">discriminator</span><span class="p">([</span><span class="n">segmentation_map</span><span class="p">,</span> <span class="n">image</span><span class="p">])</span> <span class="n">combined_outputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">combined_model</span><span class="p">(</span> <span class="p">[</span><span class="n">latent_vector</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span> <span class="n">segmentation_map</span><span class="p">]</span> <span class="p">)</span> <span class="n">fake_d_output</span><span class="p">,</span> <span class="n">fake_image</span> <span class="o">=</span> <span class="n">combined_outputs</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">combined_outputs</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">pred</span> <span class="o">=</span> <span class="n">fake_d_output</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">g_loss</span> <span class="o">=</span> <span class="n">generator_loss</span><span class="p">(</span><span class="n">pred</span><span class="p">)</span> <span class="n">kl_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_divergence_loss_coeff</span> <span class="o">*</span> <span class="n">kl_divergence_loss</span><span class="p">(</span><span class="n">mean</span><span class="p">,</span> <span class="n">variance</span><span class="p">)</span> <span class="n">vgg_loss</span> <span class="o">=</span> <span 
class="bp">self</span><span class="o">.</span><span class="n">vgg_feature_loss_coeff</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">vgg_loss</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">fake_image</span><span class="p">)</span> <span class="n">feature_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">feature_loss_coeff</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">feature_matching_loss</span><span class="p">(</span> <span class="n">real_d_output</span><span class="p">,</span> <span class="n">fake_d_output</span> <span class="p">)</span> <span class="n">total_generator_loss</span> <span class="o">=</span> <span class="n">g_loss</span> <span class="o">+</span> <span class="n">kl_loss</span> <span class="o">+</span> <span class="n">vgg_loss</span> <span class="o">+</span> <span class="n">feature_loss</span> <span class="c1"># Report progress.</span> <span class="bp">self</span><span class="o">.</span><span class="n">disc_loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">total_discriminator_loss</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">gen_loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">total_generator_loss</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">feat_loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">feature_loss</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">vgg_loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span 
class="n">vgg_loss</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">kl_loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">kl_loss</span><span class="p">)</span> <span class="n">results</span> <span class="o">=</span> <span class="p">{</span><span class="n">m</span><span class="o">.</span><span class="n">name</span><span class="p">:</span> <span class="n">m</span><span class="o">.</span><span class="n">result</span><span class="p">()</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics</span><span class="p">}</span> <span class="k">return</span> <span class="n">results</span> <span class="k">def</span><span class="w"> </span><span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">latent_vectors</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">inputs</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">generator</span><span class="p">([</span><span class="n">latent_vectors</span><span class="p">,</span> <span class="n">labels</span><span class="p">])</span> </code></pre></div> <hr /> <h2 id="gaugan-training">GauGAN training</h2> <div class="codehilite"><pre><span></span><code><span class="n">gaugan</span> <span class="o">=</span> <span class="n">GauGAN</span><span class="p">(</span><span class="n">IMG_HEIGHT</span><span class="p">,</span> <span class="n">NUM_CLASSES</span><span class="p">,</span> <span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">latent_dim</span><span class="o">=</span><span class="mi">256</span><span class="p">)</span> <span class="n">gaugan</span><span class="o">.</span><span 
class="n">compile</span><span class="p">()</span> <span class="n">history</span> <span class="o">=</span> <span class="n">gaugan</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">train_dataset</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">val_dataset</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">15</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">GanMonitor</span><span class="p">(</span><span class="n">val_dataset</span><span class="p">,</span> <span class="n">BATCH_SIZE</span><span class="p">)],</span> <span class="p">)</span> <span class="k">def</span><span class="w"> </span><span class="nf">plot_history</span><span class="p">(</span><span class="n">item</span><span class="p">):</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="n">item</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="n">item</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">&quot;val_&quot;</span> <span class="o">+</span> <span class="n">item</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s2">&quot;val_&quot;</span> <span class="o">+</span> <span class="n">item</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s2">&quot;Epochs&quot;</span><span class="p">)</span> <span 
class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="n">item</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">&quot;Train and Validation </span><span class="si">{}</span><span class="s2"> Over Epochs&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">item</span><span class="p">),</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">grid</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="n">plot_history</span><span class="p">(</span><span class="s2">&quot;disc_loss&quot;</span><span class="p">)</span> <span class="n">plot_history</span><span class="p">(</span><span class="s2">&quot;gen_loss&quot;</span><span class="p">)</span> <span class="n">plot_history</span><span class="p">(</span><span class="s2">&quot;feat_loss&quot;</span><span class="p">)</span> <span class="n">plot_history</span><span class="p">(</span><span class="s2">&quot;vgg_loss&quot;</span><span class="p">)</span> <span class="n">plot_history</span><span class="p">(</span><span class="s2">&quot;kl_loss&quot;</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/15 /home/sineeli/anaconda3/envs/kerasv3/lib/python3.10/site-packages/keras/src/optimizers/base_optimizer.py:472: UserWarning: Gradients do not exist for variables [&#39;kernel&#39;, &#39;kernel&#39;, &#39;gamma&#39;, &#39;beta&#39;, &#39;kernel&#39;, &#39;gamma&#39;, &#39;beta&#39;, &#39;kernel&#39;, &#39;gamma&#39;, 
&#39;beta&#39;, &#39;kernel&#39;, &#39;gamma&#39;, &#39;beta&#39;, &#39;kernel&#39;, &#39;bias&#39;, &#39;kernel&#39;, &#39;bias&#39;] when minimizing the loss. If using `model.compile()`, did you forget to provide a `loss` argument? warnings.warn( WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1705013303.976306 30381 device_compiler.h:186] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. W0000 00:00:1705013304.021899 30381 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update 75/75 ━━━━━━━━━━━━━━━━━━━━ 0s 176ms/step - disc_loss: 1.3079 - feat_loss: 11.2902 - gen_loss: 113.0583 - kl_loss: 83.1424 - vgg_loss: 18.4966 W0000 00:00:1705013326.657730 30384 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update 1/1 ━━━━━━━━━━━━━━━━━━━━ 3s 3s/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/generative/gaugan/gaugan_31_5.png" /></p> <p><img alt="png" src="/img/examples/generative/gaugan/gaugan_31_6.png" /></p> <p><img alt="png" src="/img/examples/generative/gaugan/gaugan_31_7.png" /></p> <p><img alt="png" src="/img/examples/generative/gaugan/gaugan_31_8.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 75/75 ━━━━━━━━━━━━━━━━━━━━ 114s 426ms/step - disc_loss: 1.3051 - feat_loss: 11.2902 - gen_loss: 113.0590 - kl_loss: 83.1493 - vgg_loss: 18.4890 - val_disc_loss: 1.0374 - val_feat_loss: 9.2344 - val_gen_loss: 110.1001 - val_kl_loss: 83.8935 - val_vgg_loss: 16.6412 Epoch 2/15 75/75 ━━━━━━━━━━━━━━━━━━━━ 14s 193ms/step - disc_loss: 0.8257 - feat_loss: 12.6603 - gen_loss: 115.9798 - kl_loss: 84.4545 - vgg_loss: 18.2973 - val_disc_loss: 0.9296 - val_feat_loss: 10.4162 - val_gen_loss: 110.6182 - val_kl_loss: 83.4473 - val_vgg_loss: 16.5499 Epoch 3/15 75/75 ━━━━━━━━━━━━━━━━━━━━ 15s 194ms/step - disc_loss: 0.9126 - feat_loss: 10.4992 - gen_loss: 111.6962 - 
kl_loss: 83.8692 - vgg_loss: 17.0433 - val_disc_loss: 0.8875 - val_feat_loss: 9.9899 - val_gen_loss: 111.4879 - val_kl_loss: 84.6905 - val_vgg_loss: 16.4510 Epoch 4/15 75/75 ━━━━━━━━━━━━━━━━━━━━ 15s 194ms/step - disc_loss: 0.8975 - feat_loss: 9.9081 - gen_loss: 111.2489 - kl_loss: 84.3098 - vgg_loss: 16.7369 - val_disc_loss: 0.9266 - val_feat_loss: 8.8318 - val_gen_loss: 107.9712 - val_kl_loss: 82.1354 - val_vgg_loss: 16.2676 Epoch 5/15 75/75 ━━━━━━━━━━━━━━━━━━━━ 15s 194ms/step - disc_loss: 0.9378 - feat_loss: 9.1914 - gen_loss: 110.5359 - kl_loss: 84.7988 - vgg_loss: 16.3160 - val_disc_loss: 1.0073 - val_feat_loss: 8.9351 - val_gen_loss: 109.2667 - val_kl_loss: 84.4920 - val_vgg_loss: 16.3844 Epoch 6/15 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 35ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/generative/gaugan/gaugan_31_10.png" /></p> <p><img alt="png" src="/img/examples/generative/gaugan/gaugan_31_11.png" /></p> <p><img alt="png" src="/img/examples/generative/gaugan/gaugan_31_12.png" /></p> <p><img alt="png" src="/img/examples/generative/gaugan/gaugan_31_13.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 75/75 ━━━━━━━━━━━━━━━━━━━━ 19s 258ms/step - disc_loss: 0.8982 - feat_loss: 9.2486 - gen_loss: 109.9399 - kl_loss: 83.8095 - vgg_loss: 16.5587 - val_disc_loss: 0.8061 - val_feat_loss: 8.5935 - val_gen_loss: 109.5937 - val_kl_loss: 84.5844 - val_vgg_loss: 15.8794 Epoch 7/15 75/75 ━━━━━━━━━━━━━━━━━━━━ 15s 194ms/step - disc_loss: 0.9048 - feat_loss: 9.1064 - gen_loss: 109.3803 - kl_loss: 83.8245 - vgg_loss: 16.0975 - val_disc_loss: 1.0096 - val_feat_loss: 7.6335 - val_gen_loss: 108.2900 - val_kl_loss: 84.8679 - val_vgg_loss: 15.9580 Epoch 8/15 75/75 ━━━━━━━━━━━━━━━━━━━━ 15s 193ms/step - disc_loss: 0.9075 - feat_loss: 8.0537 - gen_loss: 108.1771 - kl_loss: 83.6673 - vgg_loss: 16.1545 - val_disc_loss: 1.0090 - val_feat_loss: 8.7077 - val_gen_loss: 109.2079 - val_kl_loss: 84.5022 - val_vgg_loss: 16.3814 Epoch 
9/15 75/75 ━━━━━━━━━━━━━━━━━━━━ 15s 194ms/step - disc_loss: 0.9053 - feat_loss: 7.7949 - gen_loss: 107.9268 - kl_loss: 83.6504 - vgg_loss: 16.1193 - val_disc_loss: 1.0663 - val_feat_loss: 8.2042 - val_gen_loss: 108.4819 - val_kl_loss: 84.5961 - val_vgg_loss: 16.0834 Epoch 10/15 75/75 ━━━━━━━━━━━━━━━━━━━━ 15s 194ms/step - disc_loss: 0.8905 - feat_loss: 7.7652 - gen_loss: 108.3079 - kl_loss: 83.8574 - vgg_loss: 16.2992 - val_disc_loss: 0.8362 - val_feat_loss: 7.7127 - val_gen_loss: 108.9906 - val_kl_loss: 84.4822 - val_vgg_loss: 16.0521 Epoch 11/15 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 30ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/generative/gaugan/gaugan_31_15.png" /></p> <p><img alt="png" src="/img/examples/generative/gaugan/gaugan_31_16.png" /></p> <p><img alt="png" src="/img/examples/generative/gaugan/gaugan_31_17.png" /></p> <p><img alt="png" src="/img/examples/generative/gaugan/gaugan_31_18.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 75/75 ━━━━━━━━━━━━━━━━━━━━ 20s 263ms/step - disc_loss: 0.9047 - feat_loss: 7.5019 - gen_loss: 107.6317 - kl_loss: 83.6812 - vgg_loss: 16.1292 - val_disc_loss: 0.8788 - val_feat_loss: 7.7651 - val_gen_loss: 109.1731 - val_kl_loss: 84.3094 - val_vgg_loss: 16.0356 Epoch 12/15 75/75 ━━━━━━━━━━━━━━━━━━━━ 15s 194ms/step - disc_loss: 0.8899 - feat_loss: 7.5799 - gen_loss: 108.2313 - kl_loss: 84.4031 - vgg_loss: 15.9665 - val_disc_loss: 0.8358 - val_feat_loss: 7.5676 - val_gen_loss: 109.5789 - val_kl_loss: 85.7282 - val_vgg_loss: 16.0442 Epoch 13/15 75/75 ━━━━━━━━━━━━━━━━━━━━ 15s 194ms/step - disc_loss: 0.8542 - feat_loss: 7.3362 - gen_loss: 107.4649 - kl_loss: 83.6942 - vgg_loss: 16.0675 - val_disc_loss: 1.0853 - val_feat_loss: 7.9020 - val_gen_loss: 106.9958 - val_kl_loss: 84.2610 - val_vgg_loss: 15.8510 Epoch 14/15 75/75 ━━━━━━━━━━━━━━━━━━━━ 15s 194ms/step - disc_loss: 0.8631 - feat_loss: 7.6403 - gen_loss: 108.6401 - kl_loss: 84.5304 - vgg_loss: 16.0426 - 
val_disc_loss: 0.9516 - val_feat_loss: 8.8795 - val_gen_loss: 108.5215 - val_kl_loss: 83.1849 - val_vgg_loss: 16.3289 Epoch 15/15 75/75 ━━━━━━━━━━━━━━━━━━━━ 15s 194ms/step - disc_loss: 0.8939 - feat_loss: 7.5489 - gen_loss: 108.8330 - kl_loss: 85.0358 - vgg_loss: 15.9147 - val_disc_loss: 0.9616 - val_feat_loss: 8.0080 - val_gen_loss: 108.1650 - val_kl_loss: 84.7754 - val_vgg_loss: 15.9561 </code></pre></div> </div> <p><img alt="png" src="/img/examples/generative/gaugan/gaugan_31_20.png" /></p> <p><img alt="png" src="/img/examples/generative/gaugan/gaugan_31_21.png" /></p> <p><img alt="png" src="/img/examples/generative/gaugan/gaugan_31_22.png" /></p> <p><img alt="png" src="/img/examples/generative/gaugan/gaugan_31_23.png" /></p> <p><img alt="png" src="/img/examples/generative/gaugan/gaugan_31_24.png" /></p> <hr /> <h2 id="inference">Inference</h2> <div class="codehilite"><pre><span></span><code><span class="n">val_iterator</span> <span class="o">=</span> <span class="nb">iter</span><span class="p">(</span><span class="n">val_dataset</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">5</span><span class="p">):</span> <span class="n">val_images</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="n">val_iterator</span><span class="p">)</span> <span class="c1"># Sample latent from a normal distribution.</span> <span class="n">latent_vector</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">gaugan</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">gaugan</span><span class="o">.</span><span 
class="n">latent_dim</span><span class="p">),</span> <span class="n">mean</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">stddev</span><span class="o">=</span><span class="mf">2.0</span> <span class="p">)</span> <span class="c1"># Generate fake images.</span> <span class="n">fake_images</span> <span class="o">=</span> <span class="n">gaugan</span><span class="o">.</span><span class="n">predict</span><span class="p">([</span><span class="n">latent_vector</span><span class="p">,</span> <span class="n">val_images</span><span class="p">[</span><span class="mi">2</span><span class="p">]])</span> <span class="n">real_images</span> <span class="o">=</span> <span class="n">val_images</span> <span class="n">grid_row</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">fake_images</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="mi">3</span><span class="p">)</span> <span class="n">grid_col</span> <span class="o">=</span> <span class="mi">3</span> <span class="n">f</span><span class="p">,</span> <span class="n">axarr</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">grid_row</span><span class="p">,</span> <span class="n">grid_col</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="n">grid_col</span> <span class="o">*</span> <span class="mi">6</span><span class="p">,</span> <span class="n">grid_row</span> <span class="o">*</span> <span class="mi">6</span><span class="p">))</span> <span class="k">for</span> <span class="n">row</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">grid_row</span><span class="p">):</span> <span class="n">ax</span> <span class="o">=</span> 
<span class="n">axarr</span> <span class="k">if</span> <span class="n">grid_row</span> <span class="o">==</span> <span class="mi">1</span> <span class="k">else</span> <span class="n">axarr</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">((</span><span class="n">real_images</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="n">row</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">&quot;off&quot;</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Mask&quot;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">20</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">((</span><span class="n">real_images</span><span class="p">[</span><span class="mi">1</span><span class="p">][</span><span class="n">row</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">axis</span><span 
class="p">(</span><span class="s2">&quot;off&quot;</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Ground Truth&quot;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">20</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">((</span><span class="n">fake_images</span><span class="p">[</span><span class="n">row</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">&quot;off&quot;</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Generated&quot;</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">20</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/generative/gaugan/gaugan_33_1.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/step </code></pre></div> </div> <p><img alt="png" 
src="/img/examples/generative/gaugan/gaugan_33_3.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/generative/gaugan/gaugan_33_5.png" /></p> <hr /> <h2 id="final-words">Final words</h2> <ul> <li>The dataset used in this example is small. For better results, we recommend training on a larger dataset. GauGAN results were demonstrated on the <a href="https://github.com/nightrome/cocostuff">COCO-Stuff</a> and <a href="https://www.cityscapes-dataset.com/">Cityscapes</a> datasets.</li> <li>This example was inspired by Chapter 6 of <a href="https://www.packtpub.com/product/hands-on-image-generation-with-tensorflow/9781838826789">Hands-On Image Generation with TensorFlow</a> by <a href="https://www.linkedin.com/in/soonyau/">Soon-Yau Cheong</a> and by <a href="https://towardsdatascience.com/implementing-spade-using-fastai-6ad86b94030a">Implementing SPADE using fastai</a> by <a href="https://medium.com/@divyanshj.16">Divyansh Jha</a>.</li> <li>If you found this example interesting, you might want to check out <a href="https://github.com/soumik12345/tf2_gans">our repository</a>, which we are currently building. It will include reimplementations of popular GANs along with pretrained models. Our focus is on readability and on making the code as accessible as possible. Our plan is to first train our implementation of GauGAN (following the code of this example) on a larger dataset and then make the repository public. We welcome contributions!</li> <li>GauGAN2 was also released recently. 
You can check it out <a href="https://blogs.nvidia.com/blog/2021/11/22/gaugan2-ai-art-demo/">here</a>.</li> </ul> <p>Example available on HuggingFace.</p> <table> <thead> <tr> <th style="text-align: center;">Trained Model</th> <th style="text-align: center;">Demo</th> </tr> </thead> <tbody> <tr> <td style="text-align: center;"><a href="https://huggingface.co/keras-io/GauGAN-Image-generation"><img alt="Generic badge" src="https://img.shields.io/badge/%F0%9F%A4%97%20Model-GauGAN%20Image%20Generation-black.svg" /></a></td> <td style="text-align: center;"><a href="https://huggingface.co/spaces/keras-io/GauGAN_Conditional_Image_Generation"><img alt="Generic badge" src="https://img.shields.io/badge/%F0%9F%A4%97%20Spaces-GauGAN%20Image%20Generation-black.svg" /></a></td> </tr> </tbody> </table> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#gaugan-for-conditional-image-generation'>GauGAN for conditional image generation</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#data-collection'>Data collection</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#imports'>Imports</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#data-splitting'>Data splitting</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#data-loader'>Data loader</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#custom-layers'>Custom layers</a> </div> <div class='k-outline-depth-3'> <a href='#some-more-notes-on-spade'>Some more notes on SPADE</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#loss-functions'>Loss functions</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#gan-monitor-callback'>GAN monitor callback</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#subclassed-gaugan-model'>Subclassed GauGAN model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#gaugan-training'>GauGAN training</a> </div> <div class='k-outline-depth-2'> ◆ <a 
href='#inference'>Inference</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#final-words'>Final words</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>
