<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/generative/adain/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Neural Style Transfer with AdaIN"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Neural Style Transfer with AdaIN"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Neural Style Transfer with AdaIN</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End 
Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink active" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink2" href="/examples/generative/ddim/">Denoising Diffusion Implicit Models</a> <a class="nav-sublink2" href="/examples/generative/random_walks_with_stable_diffusion/">A walk through latent space with Stable Diffusion</a> <a class="nav-sublink2" href="/examples/generative/dreambooth/">DreamBooth</a> <a class="nav-sublink2" href="/examples/generative/ddpm/">Denoising Diffusion Probabilistic Models</a> <a class="nav-sublink2" href="/examples/generative/fine_tune_via_textual_inversion/">Teach 
StableDiffusion new concepts via Textual Inversion</a> <a class="nav-sublink2" href="/examples/generative/finetune_stable_diffusion/">Fine-tuning Stable Diffusion</a> <a class="nav-sublink2" href="/examples/generative/vae/">Variational AutoEncoder</a> <a class="nav-sublink2" href="/examples/generative/dcgan_overriding_train_step/">GAN overriding Model.train_step</a> <a class="nav-sublink2" href="/examples/generative/wgan_gp/">WGAN-GP overriding Model.train_step</a> <a class="nav-sublink2" href="/examples/generative/conditional_gan/">Conditional GAN</a> <a class="nav-sublink2" href="/examples/generative/cyclegan/">CycleGAN</a> <a class="nav-sublink2" href="/examples/generative/gan_ada/">Data-efficient GANs with Adaptive Discriminator Augmentation</a> <a class="nav-sublink2" href="/examples/generative/deep_dream/">Deep Dream</a> <a class="nav-sublink2" href="/examples/generative/gaugan/">GauGAN for conditional image generation</a> <a class="nav-sublink2" href="/examples/generative/pixelcnn/">PixelCNN</a> <a class="nav-sublink2" href="/examples/generative/stylegan/">Face image generation with StyleGAN</a> <a class="nav-sublink2" href="/examples/generative/vq_vae/">Vector-Quantized Variational Autoencoders</a> <a class="nav-sublink2" href="/examples/generative/neural_style_transfer/">Neural style transfer</a> <a class="nav-sublink2 active" href="/examples/generative/adain/">Neural Style Transfer with AdaIN</a> <a class="nav-sublink2" href="/examples/generative/gpt2_text_generation_with_keras_hub/">GPT2 Text Generation with KerasHub</a> <a class="nav-sublink2" href="/examples/generative/text_generation_gpt/">GPT text generation from scratch with KerasHub</a> <a class="nav-sublink2" href="/examples/generative/text_generation_with_miniature_gpt/">Text generation with a miniature GPT</a> <a class="nav-sublink2" href="/examples/generative/lstm_character_level_text_generation/">Character-level text generation with LSTM</a> <a class="nav-sublink2" 
href="/examples/generative/text_generation_fnet/">Text Generation using FNet</a> <a class="nav-sublink2" href="/examples/generative/molecule_generation/">Drug Molecule Generation with VAE</a> <a class="nav-sublink2" href="/examples/generative/wgan-graphs/">WGAN-GP with R-GCN for the generation of small molecular graphs</a> <a class="nav-sublink2" href="/examples/generative/random_walks_with_stable_diffusion_3/">A walk through latent space with Stable Diffusion 3</a> <a class="nav-sublink2" href="/examples/generative/real_nvp/">Density estimation using Real NVP</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = 
document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span 
class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/generative/'>Generative Deep Learning</a> / Neural Style Transfer with AdaIN </div> <div class='k-content'> <h1 id="neural-style-transfer-with-adain">Neural Style Transfer with AdaIN</h1> <p><strong>Author:</strong> <a href="https://twitter.com/arig23498">Aritra Roy Gosthipaty</a>, <a href="https://twitter.com/ritwik_raha">Ritwik Raha</a><br> <strong>Date created:</strong> 2021/11/08<br> <strong>Last modified:</strong> 2021/11/08<br></p> <div class='example_version_banner keras_2'>ⓘ This example uses Keras 2</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/generative/ipynb/adain.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/generative/adain.py"><strong>GitHub source</strong></a></p> <p><strong>Description:</strong> Neural Style Transfer with Adaptive Instance Normalization.</p> <h1 id="introduction">Introduction</h1> <p><a href="https://www.tensorflow.org/tutorials/generative/style_transfer">Neural Style Transfer</a> is the process of transferring the style of one image onto the content of another. This was first introduced in the seminal paper <a href="https://arxiv.org/abs/1508.06576">"A Neural Algorithm of Artistic Style"</a> by Gatys et al. 
A major limitation of the technique proposed in this work is its runtime, as the algorithm relies on a slow iterative optimization process.</p>
<p>Later work built on normalization layers such as <a href="https://arxiv.org/abs/1502.03167">Batch Normalization</a>, <a href="https://arxiv.org/abs/1701.02096">Instance Normalization</a> and <a href="https://arxiv.org/abs/1610.07629">Conditional Instance Normalization</a> allowed style transfer to be performed in new ways, no longer requiring a slow iterative process.</p>
<p>Following these papers, Xun Huang and Serge Belongie proposed <a href="https://arxiv.org/abs/1703.06868">Adaptive Instance Normalization</a> (AdaIN), which allows arbitrary style transfer in real time.</p>
<p>In this example we implement Adaptive Instance Normalization for Neural Style Transfer. The figure below shows the output of our AdaIN model trained for only <strong>30 epochs</strong>.</p>
<p><img alt="Style transfer sample gallery" src="https://i.imgur.com/zDjDuea.png" /></p>
<p>You can also try out the model on your own images with this <a href="https://huggingface.co/spaces/ariG23498/nst">Hugging Face demo</a>.</p>
<h1 id="setup">Setup</h1>
<p>We begin by importing the necessary packages. 
The global variables are hyperparameters that can be changed as needed.</p>
<div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span>
<span class="kn">from</span> <span class="nn">tensorflow</span> <span class="kn">import</span> <span class="n">keras</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
<span class="kn">import</span> <span class="nn">tensorflow_datasets</span> <span class="k">as</span> <span class="nn">tfds</span>
<span class="kn">from</span> <span class="nn">tensorflow.keras</span> <span class="kn">import</span> <span class="n">layers</span>

<span class="c1"># Defining the global variables.</span>
<span class="n">IMAGE_SIZE</span> <span class="o">=</span> <span class="p">(</span><span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">)</span>
<span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">64</span>
<span class="c1"># Training for a single epoch due to time constraints.</span>
<span class="c1"># Please use at least 30 epochs to see good results.</span>
<span class="n">EPOCHS</span> <span class="o">=</span> <span class="mi">1</span>
<span class="n">AUTOTUNE</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span>
</code></pre></div>
<hr />
<h2 id="style-transfer-sample-gallery">Style and content datasets</h2>
<p>For Neural Style Transfer we need style images and content images. 
In this example we will use <a href="https://www.kaggle.com/ikarus777/best-artworks-of-all-time">Best Artworks of All Time</a> as our style dataset and <a href="https://www.tensorflow.org/datasets/catalog/voc">Pascal VOC</a> as our content dataset.</p>
<p>This deviates from the authors' original implementation, which uses <a href="https://paperswithcode.com/dataset/wikiart">WikiArt</a> as the style dataset and <a href="https://cocodataset.org/#home">MS COCO</a> as the content dataset. We do this to create a minimal yet reproducible example.</p>
<hr />
<h2 id="downloading-the-dataset-from-kaggle">Downloading the dataset from Kaggle</h2>
<p>The <a href="https://www.kaggle.com/ikarus777/best-artworks-of-all-time">Best Artworks of All Time</a> dataset is hosted on Kaggle and can be downloaded in Colab by following these steps:</p>
<ul> <li>Follow the instructions <a href="https://github.com/Kaggle/kaggle-api">here</a> to obtain your Kaggle API keys if you don't already have them.</li> <li>Use the following command to upload the Kaggle API keys.</li> </ul>
<div class="codehilite"><pre><span></span><code><span class="kn">from</span> <span class="nn">google.colab</span> <span class="kn">import</span> <span class="n">files</span>

<span class="n">files</span><span class="o">.</span><span class="n">upload</span><span class="p">()</span>
</code></pre></div>
<ul> <li>Use the following commands to move the API keys to the proper directory and download the dataset.</li> </ul>
<div class="codehilite"><pre><span></span><code>$<span class="w"> </span>mkdir<span class="w"> </span>~/.kaggle
$<span class="w"> </span>cp<span class="w"> </span>kaggle.json<span class="w"> </span>~/.kaggle/
$<span class="w"> </span>chmod<span class="w"> </span><span class="m">600</span><span class="w"> </span>~/.kaggle/kaggle.json
$<span class="w"> </span>kaggle<span class="w"> </span>datasets<span class="w"> </span>download<span class="w"> </span>ikarus777/best-artworks-of-all-time
$<span class="w"> </span>unzip<span class="w"> </span>-qq<span class="w"> </span>best-artworks-of-all-time.zip
$<span class="w"> </span>rm<span class="w"> </span>-rf<span class="w"> </span>images
$<span class="w"> </span>mv<span class="w"> </span>resized<span class="w"> </span>artwork
$<span class="w"> </span>rm<span class="w"> </span>best-artworks-of-all-time.zip<span class="w"> </span>artists.csv
</code></pre></div>
<hr />
<h2 id="tfdata"><a href="https://www.tensorflow.org/api_docs/python/tf/data"><code>tf.data</code></a> pipeline</h2>
<p>In this section, we will build the <a href="https://www.tensorflow.org/api_docs/python/tf/data"><code>tf.data</code></a> pipeline for the project. For the style dataset, we decode, convert and resize the images from the folder. For the content images, the <code>tfds</code> module already provides a <a href="https://www.tensorflow.org/api_docs/python/tf/data"><code>tf.data</code></a> dataset.</p>
<p>Once the style and content pipelines are ready, we zip the two together to obtain the data pipeline that our model will consume.</p>
<div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">decode_and_resize</span><span class="p">(</span><span class="n">image_path</span><span class="p">):</span>
<span class="w">    </span><span class="sd">"""Decodes and resizes an image from the image file path.</span>

<span class="sd">    Args:</span>
<span class="sd">        image_path: The image file path.</span>

<span class="sd">    Returns:</span>
<span class="sd">        A resized image.</span>
<span class="sd">    """</span>
    <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">read_file</span><span class="p">(</span><span class="n">image_path</span><span class="p">)</span>
    <span class="n">image</span> <span class="o">=</span> <span 
class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">decode_jpeg</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">channels</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
    <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">convert_image_dtype</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span><span class="p">)</span>
    <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">IMAGE_SIZE</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">image</span>


<span class="k">def</span> <span class="nf">extract_image_from_voc</span><span class="p">(</span><span class="n">element</span><span class="p">):</span>
<span class="w">    </span><span class="sd">"""Extracts image from the PascalVOC dataset.</span>

<span class="sd">    Args:</span>
<span class="sd">        element: A dictionary of data.</span>

<span class="sd">    Returns:</span>
<span class="sd">        A resized image.</span>
<span class="sd">    """</span>
    <span class="n">image</span> <span class="o">=</span> <span class="n">element</span><span class="p">[</span><span class="s2">"image"</span><span class="p">]</span>
    <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">convert_image_dtype</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">dtype</span><span 
class="o">=</span><span class="s2">"float32"</span><span class="p">)</span>
    <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">IMAGE_SIZE</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">image</span>


<span class="c1"># Get the image file paths for the style images.</span>
<span class="n">style_images</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">listdir</span><span class="p">(</span><span class="s2">"/content/artwork/resized"</span><span class="p">)</span>
<span class="n">style_images</span> <span class="o">=</span> <span class="p">[</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="s2">"/content/artwork/resized"</span><span class="p">,</span> <span class="n">path</span><span class="p">)</span> <span class="k">for</span> <span class="n">path</span> <span class="ow">in</span> <span class="n">style_images</span><span class="p">]</span>

<span class="c1"># Split the style images into train, val and test sets.</span>
<span class="n">total_style_images</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">style_images</span><span class="p">)</span>
<span class="n">train_style</span> <span class="o">=</span> <span class="n">style_images</span><span class="p">[:</span> <span class="nb">int</span><span class="p">(</span><span class="mf">0.8</span> <span class="o">*</span> <span class="n">total_style_images</span><span class="p">)]</span>
<span class="n">val_style</span> <span class="o">=</span> <span class="n">style_images</span><span class="p">[</span><span class="nb">int</span><span class="p">(</span><span 
class="mf">0.8</span> <span class="o">*</span> <span class="n">total_style_images</span><span class="p">)</span> <span class="p">:</span> <span class="nb">int</span><span class="p">(</span><span class="mf">0.9</span> <span class="o">*</span> <span class="n">total_style_images</span><span class="p">)]</span>
<span class="n">test_style</span> <span class="o">=</span> <span class="n">style_images</span><span class="p">[</span><span class="nb">int</span><span class="p">(</span><span class="mf">0.9</span> <span class="o">*</span> <span class="n">total_style_images</span><span class="p">)</span> <span class="p">:]</span>

<span class="c1"># Build the style and content tf.data datasets.</span>
<span class="n">train_style_ds</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">(</span><span class="n">train_style</span><span class="p">)</span>
    <span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">decode_and_resize</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTOTUNE</span><span class="p">)</span>
    <span class="o">.</span><span class="n">repeat</span><span class="p">()</span>
<span class="p">)</span>
<span class="n">train_content_ds</span> <span class="o">=</span> <span class="n">tfds</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"voc"</span><span class="p">,</span> <span class="n">split</span><span class="o">=</span><span class="s2">"train"</span><span class="p">)</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">extract_image_from_voc</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">()</span>

<span 
class="n">val_style_ds</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">(</span><span class="n">val_style</span><span class="p">)</span>
    <span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">decode_and_resize</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTOTUNE</span><span class="p">)</span>
    <span class="o">.</span><span class="n">repeat</span><span class="p">()</span>
<span class="p">)</span>
<span class="n">val_content_ds</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">tfds</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"voc"</span><span class="p">,</span> <span class="n">split</span><span class="o">=</span><span class="s2">"validation"</span><span class="p">)</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">extract_image_from_voc</span><span class="p">)</span><span class="o">.</span><span class="n">repeat</span><span class="p">()</span>
<span class="p">)</span>

<span class="n">test_style_ds</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">(</span><span class="n">test_style</span><span class="p">)</span>
    <span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">decode_and_resize</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTOTUNE</span><span class="p">)</span>
    <span class="o">.</span><span 
class="n">repeat</span><span class="p">()</span>
<span class="p">)</span>
<span class="n">test_content_ds</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">tfds</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"voc"</span><span class="p">,</span> <span class="n">split</span><span class="o">=</span><span class="s2">"test"</span><span class="p">)</span>
    <span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">extract_image_from_voc</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTOTUNE</span><span class="p">)</span>
    <span class="o">.</span><span class="n">repeat</span><span class="p">()</span>
<span class="p">)</span>

<span class="c1"># Zipping the style and content datasets.</span>
<span class="n">train_ds</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">zip</span><span class="p">((</span><span class="n">train_style_ds</span><span class="p">,</span> <span class="n">train_content_ds</span><span class="p">))</span>
    <span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">BATCH_SIZE</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span>
    <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span>
    <span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">AUTOTUNE</span><span class="p">)</span>
<span class="p">)</span>

<span class="n">val_ds</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span 
class="n">zip</span><span class="p">((</span><span class="n">val_style_ds</span><span class="p">,</span> <span class="n">val_content_ds</span><span class="p">))</span>
    <span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">BATCH_SIZE</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span>
    <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span>
    <span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">AUTOTUNE</span><span class="p">)</span>
<span class="p">)</span>

<span class="n">test_ds</span> <span class="o">=</span> <span class="p">(</span>
    <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">zip</span><span class="p">((</span><span class="n">test_style_ds</span><span class="p">,</span> <span class="n">test_content_ds</span><span class="p">))</span>
    <span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">BATCH_SIZE</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span>
    <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span>
    <span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">AUTOTUNE</span><span class="p">)</span>
<span class="p">)</span>
</code></pre></div>
<div class="k-default-codeblock">
<div class="codehilite"><pre><span></span><code>Downloading and preparing dataset voc/2007/4.0.0 (download: 868.85 MiB, generated: Unknown size, total: 868.85 MiB) to /root/tensorflow_datasets/voc/2007/4.0.0...
Dl Completed...: 0 url [00:00, ? url/s]
Dl Size...: 0 MiB [00:00, ? MiB/s]
Extraction completed...: 0 file [00:00, ? file/s]
</code></pre></div>
</div>
<div class="k-default-codeblock">
<div class="codehilite"><pre><span></span><code>0 examples [00:00, ? examples/s]
Shuffling and writing examples to /root/tensorflow_datasets/voc/2007/4.0.0.incompleteP16YU5/voc-test.tfrecord
0%| | 0/4952 [00:00&lt;?, ? examples/s]
0 examples [00:00, ? examples/s]
Shuffling and writing examples to /root/tensorflow_datasets/voc/2007/4.0.0.incompleteP16YU5/voc-train.tfrecord
0%| | 0/2501 [00:00&lt;?, ? examples/s]
0 examples [00:00, ? examples/s]
Shuffling and writing examples to /root/tensorflow_datasets/voc/2007/4.0.0.incompleteP16YU5/voc-validation.tfrecord
0%| | 0/2510 [00:00&lt;?, ? examples/s]
Dataset voc downloaded and prepared to /root/tensorflow_datasets/voc/2007/4.0.0. Subsequent calls will reuse this data.
</code></pre></div>
</div>
<hr />
<h2 id="visualizing-the-data">Visualizing the data</h2>
<p>It is good practice to visualize the data before training. To check the correctness of our preprocessing pipeline, we visualize 10 samples from our dataset.</p>
<div class="codehilite"><pre><span></span><code><span class="n">style</span><span class="p">,</span> <span class="n">content</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">train_ds</span><span class="p">))</span>
<span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">nrows</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">ncols</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">30</span><span class="p">))</span>
<span class="p">[</span><span 
class="n">ax</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="k">for</span> <span class="n">ax</span> <span class="ow">in</span> <span class="n">np</span><span class="o">.</span><span class="n">ravel</span><span class="p">(</span><span class="n">axes</span><span class="p">)]</span> <span class="k">for</span> <span class="p">(</span><span class="n">axis</span><span class="p">,</span> <span class="n">style_image</span><span class="p">,</span> <span class="n">content_image</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">axes</span><span class="p">,</span> <span class="n">style</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">10</span><span class="p">],</span> <span class="n">content</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">10</span><span class="p">]):</span> <span class="p">(</span><span class="n">ax_style</span><span class="p">,</span> <span class="n">ax_content</span><span class="p">)</span> <span class="o">=</span> <span class="n">axis</span> <span class="n">ax_style</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">style_image</span><span class="p">)</span> <span class="n">ax_style</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">"Style Image"</span><span class="p">)</span> <span class="n">ax_content</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">content_image</span><span class="p">)</span> <span class="n">ax_content</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">"Content Image"</span><span class="p">)</span> </code></pre></div> <p><img alt="png" 
src="/img/examples/generative/adain/adain_8_0.png" /></p> <hr /> <h2 id="architecture">Architecture</h2> <p>The style transfer network takes a content image and a style image as inputs and outputs the style-transferred image. The authors of AdaIN propose a simple encoder-decoder structure for achieving this.</p> <p><img alt="AdaIN architecture" src="https://i.imgur.com/JbIfoyE.png" /></p> <p>The content image (<code>C</code>) and the style image (<code>S</code>) are both fed to the encoder networks. The outputs of these encoder networks (feature maps) are then fed to the AdaIN layer, which computes a combined feature map. This feature map is then fed into a randomly initialized decoder network that serves as the generator for the neural style transferred image.</p> <p><img alt="AdaIN equation" src="https://i.imgur.com/hqhcBQS.png" /></p> <p>The style feature map (<code>fs</code>) and the content feature map (<code>fc</code>) are fed to the AdaIN layer. This layer produces the combined feature map <code>t</code>. The function <code>g</code> represents the decoder (generator) network.</p> <h3 id="encoder">Encoder</h3> <p>The encoder is part of the VGG19 model pretrained on <a href="https://www.image-net.org/">ImageNet</a>. We slice the model at the <code>block4_conv1</code> layer. 
The output layer is as suggested by the authors in their paper.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">get_encoder</span><span class="p">():</span> <span class="n">vgg19</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">applications</span><span class="o">.</span><span class="n">VGG19</span><span class="p">(</span> <span class="n">include_top</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">weights</span><span class="o">=</span><span class="s2">"imagenet"</span><span class="p">,</span> <span class="n">input_shape</span><span class="o">=</span><span class="p">(</span><span class="o">*</span><span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="p">)</span> <span class="n">vgg19</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">False</span> <span class="n">mini_vgg19</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">vgg19</span><span class="o">.</span><span class="n">input</span><span class="p">,</span> <span class="n">vgg19</span><span class="o">.</span><span class="n">get_layer</span><span class="p">(</span><span class="s2">"block4_conv1"</span><span class="p">)</span><span class="o">.</span><span class="n">output</span><span class="p">)</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">([</span><span class="o">*</span><span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="n">mini_vgg19_out</span> <span class="o">=</span> <span class="n">mini_vgg19</span><span class="p">(</span><span class="n">inputs</span><span 
class="p">)</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">mini_vgg19_out</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"mini_vgg19"</span><span class="p">)</span> </code></pre></div> <h3 id="adaptive-instance-normalization">Adaptive Instance Normalization</h3> <p>The AdaIN layer takes in the features of the content and style images. The layer can be defined via the following equation:</p> <p><img alt="AdaIN formula" src="https://i.imgur.com/tWq3VKP.png" /></p> <p>where <code>sigma</code> is the standard deviation and <code>mu</code> is the mean of the corresponding variable. In the above equation, the mean and variance of the content feature map <code>fc</code> are aligned with the mean and variance of the style feature map <code>fs</code>.</p> <p>It is important to note that the AdaIN layer proposed by the authors uses no parameters other than the mean and variance. The layer also does not have any trainable parameters. This is why we use a <em>Python function</em> instead of a <em>Keras layer</em>. 
The function takes the style and content feature maps, computes the mean and standard deviation of each, and returns the adaptive instance normalized feature map.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">get_mean_std</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">):</span> <span class="n">axes</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">]</span> <span class="c1"># Compute the per-channel mean and standard deviation over the spatial axes.</span> <span class="n">mean</span><span class="p">,</span> <span class="n">variance</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">moments</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="n">axes</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">standard_deviation</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">variance</span> <span class="o">+</span> <span class="n">epsilon</span><span class="p">)</span> <span class="k">return</span> <span class="n">mean</span><span class="p">,</span> <span class="n">standard_deviation</span> <span class="k">def</span> <span class="nf">ada_in</span><span class="p">(</span><span class="n">style</span><span class="p">,</span> <span class="n">content</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Computes the AdaIN feature map.</span> <span class="sd"> Args:</span> <span class="sd"> style: The style feature map.</span> 
<span class="sd"> content: The content feature map.</span> <span class="sd"> Returns:</span> <span class="sd"> The AdaIN feature map.</span> <span class="sd"> """</span> <span class="n">content_mean</span><span class="p">,</span> <span class="n">content_std</span> <span class="o">=</span> <span class="n">get_mean_std</span><span class="p">(</span><span class="n">content</span><span class="p">)</span> <span class="n">style_mean</span><span class="p">,</span> <span class="n">style_std</span> <span class="o">=</span> <span class="n">get_mean_std</span><span class="p">(</span><span class="n">style</span><span class="p">)</span> <span class="n">t</span> <span class="o">=</span> <span class="n">style_std</span> <span class="o">*</span> <span class="p">(</span><span class="n">content</span> <span class="o">-</span> <span class="n">content_mean</span><span class="p">)</span> <span class="o">/</span> <span class="n">content_std</span> <span class="o">+</span> <span class="n">style_mean</span> <span class="k">return</span> <span class="n">t</span> </code></pre></div> <h3 id="decoder">Decoder</h3> <p>The authors specify that the decoder network must mirror the encoder network. We have symmetrically inverted the encoder to build our decoder. 
We have used <code>UpSampling2D</code> layers to increase the spatial resolution of the feature maps.</p> <p>Note that the authors warn against using any normalization layer in the decoder network, and do indeed go on to show that including batch normalization or instance normalization hurts the performance of the overall network.</p> <p>This is the only portion of the entire architecture that is trainable.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">get_decoder</span><span class="p">():</span> <span class="n">config</span> <span class="o">=</span> <span class="p">{</span><span class="s2">"kernel_size"</span><span class="p">:</span> <span class="mi">3</span><span class="p">,</span> <span class="s2">"strides"</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">"padding"</span><span class="p">:</span> <span class="s2">"same"</span><span class="p">,</span> <span class="s2">"activation"</span><span class="p">:</span> <span class="s2">"relu"</span><span class="p">}</span> <span class="n">decoder</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">layers</span><span class="o">.</span><span class="n">InputLayer</span><span class="p">((</span><span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="mi">512</span><span class="p">)),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="o">**</span><span class="n">config</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">UpSampling2D</span><span class="p">(),</span> <span class="n">layers</span><span 
class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> <span class="o">**</span><span class="n">config</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> <span class="o">**</span><span class="n">config</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> <span class="o">**</span><span class="n">config</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> <span class="o">**</span><span class="n">config</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">UpSampling2D</span><span class="p">(),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="o">**</span><span class="n">config</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="o">**</span><span class="n">config</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">UpSampling2D</span><span class="p">(),</span> <span class="n">layers</span><span 
class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="o">**</span><span class="n">config</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span> <span class="n">filters</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"sigmoid"</span><span class="p">,</span> <span class="p">),</span> <span class="p">]</span> <span class="p">)</span> <span class="k">return</span> <span class="n">decoder</span> </code></pre></div> <h3 id="loss-functions">Loss functions</h3> <p>Here we build the loss functions for the neural style transfer model. The authors propose to use a pretrained VGG-19 to compute the loss function of the network. It is important to keep in mind that this will be used for training only the decoder network. The total loss (<code>Lt</code>) is a weighted combination of content loss (<code>Lc</code>) and style loss (<code>Ls</code>). 
The <code>lambda</code> term is used to vary the amount of style transferred.</p> <p><img alt="The total loss" src="https://i.imgur.com/Q5y1jUM.png" /></p> <h3 id="content-loss">Content Loss</h3> <p>This is the Euclidean distance between the target content features and the features of the neural style transferred image.</p> <p><img alt="The content loss" src="https://i.imgur.com/dZ0uD0N.png" /></p> <p>Here the authors propose to use the output of the AdaIN layer, <code>t</code>, as the content target rather than the features of the original content image. This is done to speed up convergence.</p> <h3 id="style-loss">Style Loss</h3> <p>Rather than using the more common <a href="https://mathworld.wolfram.com/GramMatrix.html">Gram Matrix</a>, the authors propose to compare the feature statistics (mean and standard deviation) directly, which is conceptually cleaner. This is expressed by the following equation:</p> <p><img alt="The style loss" src="https://i.imgur.com/Ctclhn3.png" /></p> <p>where <code>theta</code> denotes the layers in VGG-19 used to compute the loss. 
In this case, this corresponds to:</p> <ul> <li><code>block1_conv1</code></li> <li><code>block2_conv1</code></li> <li><code>block3_conv1</code></li> <li><code>block4_conv1</code></li> </ul> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">get_loss_net</span><span class="p">():</span> <span class="n">vgg19</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">applications</span><span class="o">.</span><span class="n">VGG19</span><span class="p">(</span> <span class="n">include_top</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">weights</span><span class="o">=</span><span class="s2">"imagenet"</span><span class="p">,</span> <span class="n">input_shape</span><span class="o">=</span><span class="p">(</span><span class="o">*</span><span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="p">)</span> <span class="n">vgg19</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">False</span> <span class="n">layer_names</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"block1_conv1"</span><span class="p">,</span> <span class="s2">"block2_conv1"</span><span class="p">,</span> <span class="s2">"block3_conv1"</span><span class="p">,</span> <span class="s2">"block4_conv1"</span><span class="p">]</span> <span class="n">outputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">vgg19</span><span class="o">.</span><span class="n">get_layer</span><span class="p">(</span><span class="n">name</span><span class="p">)</span><span class="o">.</span><span class="n">output</span> <span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">layer_names</span><span class="p">]</span> <span class="n">mini_vgg19</span> <span class="o">=</span> <span 
class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">vgg19</span><span class="o">.</span><span class="n">input</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">([</span><span class="o">*</span><span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="n">mini_vgg19_out</span> <span class="o">=</span> <span class="n">mini_vgg19</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">mini_vgg19_out</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"loss_net"</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="neural-style-transfer">Neural Style Transfer</h2> <p>This is the trainer module. We wrap the encoder and decoder inside a <a href="/api/models/model#model-class"><code>tf.keras.Model</code></a> subclass. 
This allows us to customize what happens in the <code>model.fit()</code> loop.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">NeuralStyleTransfer</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">encoder</span><span class="p">,</span> <span class="n">decoder</span><span class="p">,</span> <span class="n">loss_net</span><span class="p">,</span> <span class="n">style_weight</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">encoder</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span> <span class="o">=</span> <span class="n">decoder</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_net</span> <span class="o">=</span> <span class="n">loss_net</span> <span class="bp">self</span><span class="o">.</span><span class="n">style_weight</span> <span class="o">=</span> <span class="n">style_weight</span> <span class="k">def</span> <span class="nf">compile</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">compile</span><span class="p">()</span> <span 
class="bp">self</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span> <span class="o">=</span> <span class="n">loss_fn</span> <span class="bp">self</span><span class="o">.</span><span class="n">style_loss_tracker</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"style_loss"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">content_loss_tracker</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"content_loss"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_loss_tracker</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"total_loss"</span><span class="p">)</span> <span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">style</span><span class="p">,</span> <span class="n">content</span> <span class="o">=</span> <span class="n">inputs</span> <span class="c1"># Initialize the content and style loss.</span> <span class="n">loss_content</span> <span class="o">=</span> <span class="mf">0.0</span> <span class="n">loss_style</span> <span class="o">=</span> <span 
class="mf">0.0</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="c1"># Encode the style and content image.</span> <span class="n">style_encoded</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">style</span><span class="p">)</span> <span class="n">content_encoded</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">content</span><span class="p">)</span> <span class="c1"># Compute the AdaIN target feature maps.</span> <span class="n">t</span> <span class="o">=</span> <span class="n">ada_in</span><span class="p">(</span><span class="n">style</span><span class="o">=</span><span class="n">style_encoded</span><span class="p">,</span> <span class="n">content</span><span class="o">=</span><span class="n">content_encoded</span><span class="p">)</span> <span class="c1"># Generate the neural style transferred image.</span> <span class="n">reconstructed_image</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="n">t</span><span class="p">)</span> <span class="c1"># Compute the losses.</span> <span class="n">reconstructed_vgg_features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_net</span><span class="p">(</span><span class="n">reconstructed_image</span><span class="p">)</span> <span class="n">style_vgg_features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_net</span><span class="p">(</span><span class="n">style</span><span class="p">)</span> <span class="n">loss_content</span> 
<span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span><span class="p">(</span><span class="n">t</span><span class="p">,</span> <span class="n">reconstructed_vgg_features</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span> <span class="k">for</span> <span class="n">inp</span><span class="p">,</span> <span class="n">out</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">style_vgg_features</span><span class="p">,</span> <span class="n">reconstructed_vgg_features</span><span class="p">):</span> <span class="n">mean_inp</span><span class="p">,</span> <span class="n">std_inp</span> <span class="o">=</span> <span class="n">get_mean_std</span><span class="p">(</span><span class="n">inp</span><span class="p">)</span> <span class="n">mean_out</span><span class="p">,</span> <span class="n">std_out</span> <span class="o">=</span> <span class="n">get_mean_std</span><span class="p">(</span><span class="n">out</span><span class="p">)</span> <span class="n">loss_style</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span><span class="p">(</span><span class="n">mean_inp</span><span class="p">,</span> <span class="n">mean_out</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span><span class="p">(</span> <span class="n">std_inp</span><span class="p">,</span> <span class="n">std_out</span> <span class="p">)</span> <span class="n">loss_style</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">style_weight</span> <span class="o">*</span> <span class="n">loss_style</span> <span class="n">total_loss</span> <span class="o">=</span> <span class="n">loss_content</span> <span class="o">+</span> <span class="n">loss_style</span> <span class="c1"># 
Compute gradients and optimize the decoder.</span> <span class="n">trainable_vars</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="o">.</span><span class="n">trainable_variables</span> <span class="n">gradients</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">total_loss</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">))</span> <span class="c1"># Update the trackers.</span> <span class="bp">self</span><span class="o">.</span><span class="n">style_loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">loss_style</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">content_loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">loss_content</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">total_loss</span><span class="p">)</span> <span class="k">return</span> <span class="p">{</span> <span class="s2">"style_loss"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">style_loss_tracker</span><span class="o">.</span><span class="n">result</span><span class="p">(),</span> <span class="s2">"content_loss"</span><span class="p">:</span> <span 
class="bp">self</span><span class="o">.</span><span class="n">content_loss_tracker</span><span class="o">.</span><span class="n">result</span><span class="p">(),</span> <span class="s2">"total_loss"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_loss_tracker</span><span class="o">.</span><span class="n">result</span><span class="p">(),</span> <span class="p">}</span> <span class="k">def</span> <span class="nf">test_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">style</span><span class="p">,</span> <span class="n">content</span> <span class="o">=</span> <span class="n">inputs</span> <span class="c1"># Initialize the content and style loss.</span> <span class="n">loss_content</span> <span class="o">=</span> <span class="mf">0.0</span> <span class="n">loss_style</span> <span class="o">=</span> <span class="mf">0.0</span> <span class="c1"># Encode the style and content image.</span> <span class="n">style_encoded</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">style</span><span class="p">)</span> <span class="n">content_encoded</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">content</span><span class="p">)</span> <span class="c1"># Compute the AdaIN target feature maps.</span> <span class="n">t</span> <span class="o">=</span> <span class="n">ada_in</span><span class="p">(</span><span class="n">style</span><span class="o">=</span><span class="n">style_encoded</span><span class="p">,</span> <span class="n">content</span><span class="o">=</span><span class="n">content_encoded</span><span class="p">)</span> <span class="c1"># Generate the neural style transferred image.</span> <span 
class="n">reconstructed_image</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="n">t</span><span class="p">)</span> <span class="c1"># Compute the losses.</span> <span class="n">recons_vgg_features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_net</span><span class="p">(</span><span class="n">reconstructed_image</span><span class="p">)</span> <span class="n">style_vgg_features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_net</span><span class="p">(</span><span class="n">style</span><span class="p">)</span> <span class="n">loss_content</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span><span class="p">(</span><span class="n">t</span><span class="p">,</span> <span class="n">recons_vgg_features</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span> <span class="k">for</span> <span class="n">inp</span><span class="p">,</span> <span class="n">out</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">style_vgg_features</span><span class="p">,</span> <span class="n">recons_vgg_features</span><span class="p">):</span> <span class="n">mean_inp</span><span class="p">,</span> <span class="n">std_inp</span> <span class="o">=</span> <span class="n">get_mean_std</span><span class="p">(</span><span class="n">inp</span><span class="p">)</span> <span class="n">mean_out</span><span class="p">,</span> <span class="n">std_out</span> <span class="o">=</span> <span class="n">get_mean_std</span><span class="p">(</span><span class="n">out</span><span class="p">)</span> <span class="n">loss_style</span> <span class="o">+=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span><span 
class="p">(</span><span class="n">mean_inp</span><span class="p">,</span> <span class="n">mean_out</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span><span class="p">(</span> <span class="n">std_inp</span><span class="p">,</span> <span class="n">std_out</span> <span class="p">)</span> <span class="n">loss_style</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">style_weight</span> <span class="o">*</span> <span class="n">loss_style</span> <span class="n">total_loss</span> <span class="o">=</span> <span class="n">loss_content</span> <span class="o">+</span> <span class="n">loss_style</span> <span class="c1"># Update the trackers.</span> <span class="bp">self</span><span class="o">.</span><span class="n">style_loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">loss_style</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">content_loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">loss_content</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">total_loss</span><span class="p">)</span> <span class="k">return</span> <span class="p">{</span> <span class="s2">"style_loss"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">style_loss_tracker</span><span class="o">.</span><span class="n">result</span><span class="p">(),</span> <span class="s2">"content_loss"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">content_loss_tracker</span><span class="o">.</span><span class="n">result</span><span 
class="p">(),</span> <span class="s2">"total_loss"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_loss_tracker</span><span class="o">.</span><span class="n">result</span><span class="p">(),</span> <span class="p">}</span> <span class="nd">@property</span> <span class="k">def</span> <span class="nf">metrics</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="k">return</span> <span class="p">[</span> <span class="bp">self</span><span class="o">.</span><span class="n">style_loss_tracker</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">content_loss_tracker</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_loss_tracker</span><span class="p">,</span> <span class="p">]</span> </code></pre></div> <hr /> <h2 id="train-monitor-callback">Train Monitor callback</h2> <p>This callback visualizes the style transfer output of the model at the end of each epoch. The quality of style transfer is hard to quantify and is ultimately judged subjectively by viewers. 
For this reason, visualization is a key aspect of evaluating the model.</p> <div class="codehilite"><pre><span></span><code><span class="n">test_style</span><span class="p">,</span> <span class="n">test_content</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">test_ds</span><span class="p">))</span> <span class="k">class</span> <span class="nc">TrainMonitor</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">Callback</span><span class="p">):</span> <span class="k">def</span> <span class="nf">on_epoch_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="c1"># Encode the style and content image.</span> <span class="n">test_style_encoded</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">test_style</span><span class="p">)</span> <span class="n">test_content_encoded</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">test_content</span><span class="p">)</span> <span class="c1"># Compute the AdaIN features.</span> <span class="n">test_t</span> <span class="o">=</span> <span class="n">ada_in</span><span class="p">(</span><span class="n">style</span><span class="o">=</span><span class="n">test_style_encoded</span><span class="p">,</span> <span class="n">content</span><span class="o">=</span><span 
class="n">test_content_encoded</span><span class="p">)</span> <span class="n">test_reconstructed_image</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="n">test_t</span><span class="p">)</span> <span class="c1"># Plot the Style, Content and the NST image.</span> <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">nrows</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">ncols</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span> <span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">array_to_img</span><span class="p">(</span><span class="n">test_style</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span> <span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Style: </span><span class="si">{</span><span class="n">epoch</span><span class="si">:</span><span class="s2">03d</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span 
class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">array_to_img</span><span class="p">(</span><span class="n">test_content</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span> <span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Content: </span><span class="si">{</span><span class="n">epoch</span><span class="si">:</span><span class="s2">03d</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">array_to_img</span><span class="p">(</span><span class="n">test_reconstructed_image</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s2">"NST: </span><span class="si">{</span><span class="n">epoch</span><span class="si">:</span><span class="s2">03d</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">close</span><span class="p">()</span> 
</code></pre></div> <hr /> <h2 id="train-the-model">Train the model</h2> <p>In this section, we define the optimizer, the loss function, and the trainer module. We compile the trainer module with the optimizer and the loss function and then train it.</p> <p><em>Note</em>: We train the model for a single epoch because of time constraints, but it needs to be trained for at least 30 epochs to produce good results.</p> <div class="codehilite"><pre><span></span><code><span class="n">optimizer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">)</span> <span class="n">loss_fn</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">MeanSquaredError</span><span class="p">()</span> <span class="n">encoder</span> <span class="o">=</span> <span class="n">get_encoder</span><span class="p">()</span> <span class="n">loss_net</span> <span class="o">=</span> <span class="n">get_loss_net</span><span class="p">()</span> <span class="n">decoder</span> <span class="o">=</span> <span class="n">get_decoder</span><span class="p">()</span> <span class="n">model</span> <span class="o">=</span> <span class="n">NeuralStyleTransfer</span><span class="p">(</span> <span class="n">encoder</span><span class="o">=</span><span class="n">encoder</span><span class="p">,</span> <span class="n">decoder</span><span class="o">=</span><span class="n">decoder</span><span class="p">,</span> <span class="n">loss_net</span><span class="o">=</span><span class="n">loss_net</span><span class="p">,</span> <span class="n">style_weight</span><span class="o">=</span><span class="mf">4.0</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span 
class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">loss_fn</span><span class="o">=</span><span class="n">loss_fn</span><span class="p">)</span> <span class="n">history</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">train_ds</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">EPOCHS</span><span class="p">,</span> <span class="n">steps_per_epoch</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">val_ds</span><span class="p">,</span> <span class="n">validation_steps</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">TrainMonitor</span><span class="p">()],</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5 80142336/80134624 [==============================] - 1s 0us/step 80150528/80134624 [==============================] - 1s 0us/step 50/50 [==============================] - ETA: 0s - style_loss: 213.1439 - content_loss: 141.1564 - total_loss: 354.3002 </code></pre></div> </div> <p><img alt="png" src="/img/examples/generative/adain/adain_23_1.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>50/50 [==============================] - 124s 2s/step - style_loss: 213.1439 - content_loss: 141.1564 - total_loss: 354.3002 - val_style_loss: 167.0819 - val_content_loss: 129.0497 - val_total_loss: 296.1316 
</code></pre></div> </div> <hr /> <h2 id="inference">Inference</h2> <p>With the model trained, we can now run inference with it. We pass arbitrary content and style images from the test dataset through the model and inspect the output images.</p> <p><em>Note</em>: To try out the model on your own images, you can use this <a href="https://huggingface.co/spaces/ariG23498/nst">Hugging Face demo</a>.</p> <div class="codehilite"><pre><span></span><code><span class="k">for</span> <span class="n">style</span><span class="p">,</span> <span class="n">content</span> <span class="ow">in</span> <span class="n">test_ds</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="mi">1</span><span class="p">):</span> <span class="n">style_encoded</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">style</span><span class="p">)</span> <span class="n">content_encoded</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">content</span><span class="p">)</span> <span class="n">t</span> <span class="o">=</span> <span class="n">ada_in</span><span class="p">(</span><span class="n">style</span><span class="o">=</span><span class="n">style_encoded</span><span class="p">,</span> <span class="n">content</span><span class="o">=</span><span class="n">content_encoded</span><span class="p">)</span> <span class="n">reconstructed_image</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">decoder</span><span class="p">(</span><span class="n">t</span><span class="p">)</span> <span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">nrows</span><span 
class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">ncols</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">30</span><span class="p">))</span> <span class="p">[</span><span class="n">ax</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="k">for</span> <span class="n">ax</span> <span class="ow">in</span> <span class="n">np</span><span class="o">.</span><span class="n">ravel</span><span class="p">(</span><span class="n">axes</span><span class="p">)]</span> <span class="k">for</span> <span class="n">axis</span><span class="p">,</span> <span class="n">style_image</span><span class="p">,</span> <span class="n">content_image</span><span class="p">,</span> <span class="n">reconstructed_image</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span> <span class="n">axes</span><span class="p">,</span> <span class="n">style</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">10</span><span class="p">],</span> <span class="n">content</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">10</span><span class="p">],</span> <span class="n">reconstructed_image</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">10</span><span class="p">]</span> <span class="p">):</span> <span class="p">(</span><span class="n">ax_style</span><span class="p">,</span> <span class="n">ax_content</span><span class="p">,</span> <span class="n">ax_reconstructed</span><span class="p">)</span> <span class="o">=</span> <span class="n">axis</span> <span class="n">ax_style</span><span class="o">.</span><span class="n">imshow</span><span 
class="p">(</span><span class="n">style_image</span><span class="p">)</span> <span class="n">ax_style</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">"Style Image"</span><span class="p">)</span> <span class="n">ax_content</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">content_image</span><span class="p">)</span> <span class="n">ax_content</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">"Content Image"</span><span class="p">)</span> <span class="n">ax_reconstructed</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">reconstructed_image</span><span class="p">)</span> <span class="n">ax_reconstructed</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">"NST Image"</span><span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/examples/generative/adain/adain_25_0.png" /></p> <hr /> <h2 id="conclusion">Conclusion</h2> <p>Adaptive Instance Normalization allows arbitrary style transfer in real time. 
The authors' key contribution is achieving this solely by aligning the statistics (mean and standard deviation) of the content features with those of the style features.</p> <p><em>Note</em>: AdaIN is also a building block of <a href="https://arxiv.org/abs/1812.04948">StyleGAN</a>.</p> <hr /> <h2 id="reference">Reference</h2> <ul> <li><a href="https://github.com/ftokarev/tf-adain">TF implementation</a></li> </ul> <hr /> <h2 id="acknowledgement">Acknowledgement</h2> <p>We thank <a href="https://lukewood.xyz">Luke Wood</a> for his detailed review.</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#neural-style-transfer-with-adain'>Neural Style Transfer with AdaIN</a> </div> <div class='k-outline-depth-1'> <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-1'> <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#style-transfer-sample-gallery'>Style transfer sample gallery</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#downloading-the-dataset-from-kaggle'>Downloading the dataset from Kaggle</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#tfdata-pipeline'><code>tf.data</code> pipeline</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#visualizing-the-data'>Visualizing the data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#architecture'>Architecture</a> </div> <div class='k-outline-depth-3'> <a href='#encoder'>Encoder</a> </div> <div class='k-outline-depth-3'> <a href='#adaptive-instance-normalization'>Adaptive Instance Normalization</a> </div> <div class='k-outline-depth-3'> <a href='#decoder'>Decoder</a> </div> <div class='k-outline-depth-3'> <a href='#loss-functions'>Loss functions</a> </div> <div class='k-outline-depth-3'> <a href='#content-loss'>Content Loss</a> </div> <div class='k-outline-depth-3'> <a href='#style-loss'>Style Loss</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#neural-style-transfer'>Neural Style 
Transfer</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-monitor-callback'>Train Monitor callback</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-the-model'>Train the model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#inference'>Inference</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#conclusion'>Conclusion</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#reference'>Reference</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#acknowledgement'>Acknowledgement</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>