Enhanced Deep Residual Networks for single-image super-resolution
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/edsr/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Enhanced Deep Residual Networks for single-image super-resolution"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Enhanced Deep Residual Networks for single-image super-resolution"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Enhanced Deep Residual Networks for single-image super-resolution</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) 
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a class="nav-sublink2" href="/examples/vision/mobilevit/">A 
mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using Composable Fully-Convolutional Networks</a> <a 
class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2 active" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved Robustness</a> <a class="nav-sublink2" 
href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/bit/">Image 
Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a class="nav-sublink2" 
href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = 
navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> 
<span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/vision/'>Computer Vision</a> / Enhanced Deep Residual Networks for single-image super-resolution </div> <div class='k-content'> <h1 id="enhanced-deep-residual-networks-for-singleimage-superresolution">Enhanced Deep Residual Networks for single-image super-resolution</h1> <p><strong>Author:</strong> Gitesh Chawda<br> <strong>Date created:</strong> 2022/04/07<br> <strong>Last modified:</strong> 2024/08/27<br> <strong>Description:</strong> Training an EDSR model on the DIV2K Dataset.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/edsr.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/edsr.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>In this example, we implement <a href="https://arxiv.org/abs/1707.02921">Enhanced Deep Residual Networks for Single Image Super-Resolution (EDSR)</a> by Bee Lim, Sanghyun Son, Heewon Kim, Seungjun Nah, and Kyoung Mu Lee.</p> <p>The EDSR architecture is based on the SRResNet architecture and consists of multiple residual blocks. It replaces batch normalization layers with constant scaling layers to produce consistent results: the input and output images have similar distributions, so normalizing intermediate features may not be desirable. 
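</p> <p>The residual block described above can be sketched in Keras as follows. This is a minimal sketch under stated assumptions, not the exact block used in this example: the helper name <code>scaled_res_block</code> is hypothetical, the constant scaling layer is implemented with <code>layers.Rescaling</code>, and the scaling factor of 0.1 is an illustrative choice. Each block applies Conv-ReLU-Conv with no batch normalization, multiplies the residual branch by a constant, and adds the block input back, so the output has the same shape as the input.</p> <div class="codehilite"><pre><span></span><code>import numpy as np
import keras
from keras import layers


def scaled_res_block(inputs, filters=64, scaling=0.1):
    """Hypothetical EDSR-style residual block: Conv-ReLU-Conv without batch norm.

    `scaling` (assumed 0.1 here) multiplies the residual branch by a constant
    before it is added back to the block input.
    """
    x = layers.Conv2D(filters, 3, padding="same", activation="relu")(inputs)
    x = layers.Conv2D(filters, 3, padding="same")(x)
    # Constant scaling in place of batch normalization (no trainable weights).
    x = layers.Rescaling(scaling)(x)
    # Skip connection: the block input and output must have the same shape.
    return layers.Add()([inputs, x])


# Build a single block for 64-channel feature maps of any spatial size.
inputs = keras.Input(shape=(None, None, 64))
block = keras.Model(inputs, scaled_res_block(inputs))

# A random 24x24 feature map with 64 channels keeps its shape through the block.
features = np.random.rand(1, 24, 24, 64).astype("float32")
print(block.predict(features, verbose=0).shape)  # (1, 24, 24, 64)
</code></pre></div> <p>Because the multiplier is a fixed constant rather than a normalization over batch statistics, the block behaves identically during training and inference, and setting the factor to 0 reduces it to the identity mapping.</p> <p>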
Instead of using an L2 loss (mean squared error), the authors employed an L1 loss (mean absolute error), which performs better empirically.</p> <p>Our implementation includes only 16 residual blocks with 64 channels each, matching the paper's baseline model rather than its larger 32-block, 256-channel EDSR model.</p> <p>Alternatively, as shown in the Keras example <a href="https://keras.io/examples/vision/super_resolution_sub_pixel/#image-superresolution-using-an-efficient-subpixel-cnn">Image Super-Resolution using an Efficient Sub-Pixel CNN</a>, you can perform super-resolution with an ESPCN model. According to the survey paper below, EDSR is one of the five best-performing super-resolution methods in terms of PSNR, reaching ≈34 dB compared with ≈32 dB for ESPCN. However, it has more parameters and requires more computational power than other approaches.</p> <p>Paper: <a href="https://arxiv.org/abs/2102.09351">A comprehensive review of deep learning based single image super-resolution</a></p> <p>Comparison graph: <img src="https://dfzljdn9uc3pi.cloudfront.net/2021/cs-621/1/fig-11-2x.jpg" width="500" /></p> <hr /> <h2 id="imports">Imports</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"tensorflow"</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span> <span class="kn">import</span> <span class="nn">tensorflow_datasets</span> <span class="k">as</span> <span class="nn">tfds</span> <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span> <span class="kn">import</span> 
<span class="nn">keras</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">ops</span> <span class="n">AUTOTUNE</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span> </code></pre></div> <hr /> <h2 id="download-the-training-dataset">Download the training dataset</h2> <p>We use the DIV2K dataset, a prominent single-image super-resolution dataset of 1,000 images of diverse scenes with various degradations, divided into 800 images for training, 100 images for validation, and 100 images for testing. As our low-resolution ("low quality") inputs, we use images downsampled 4x with bicubic interpolation.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Download DIV2K from TF Datasets</span> <span class="c1"># Using bicubic 4x degradation type</span> <span class="n">div2k_data</span> <span class="o">=</span> <span class="n">tfds</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">Div2k</span><span class="p">(</span><span class="n">config</span><span class="o">=</span><span class="s2">"bicubic_x4"</span><span class="p">)</span> <span class="n">div2k_data</span><span class="o">.</span><span class="n">download_and_prepare</span><span class="p">()</span> <span class="c1"># Load the training split from the div2k_data builder</span> <span class="n">train</span> <span class="o">=</span> <span class="n">div2k_data</span><span class="o">.</span><span class="n">as_dataset</span><span class="p">(</span><span class="n">split</span><span class="o">=</span><span class="s2">"train"</span><span class="p">,</span> <span class="n">as_supervised</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">train_cache</span> <span 
class="o">=</span> <span class="n">train</span><span class="o">.</span><span class="n">cache</span><span class="p">()</span> <span class="c1"># Validation data</span> <span class="n">val</span> <span class="o">=</span> <span class="n">div2k_data</span><span class="o">.</span><span class="n">as_dataset</span><span class="p">(</span><span class="n">split</span><span class="o">=</span><span class="s2">"validation"</span><span class="p">,</span> <span class="n">as_supervised</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">val_cache</span> <span class="o">=</span> <span class="n">val</span><span class="o">.</span><span class="n">cache</span><span class="p">()</span> </code></pre></div> <hr /> <h2 id="flip-crop-and-resize-images">Flip, rotate and crop images</h2> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">flip_left_right</span><span class="p">(</span><span class="n">lowres_img</span><span class="p">,</span> <span class="n">highres_img</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Randomly flips both images horizontally (left to right)."""</span> <span class="c1"># Draw a random scalar uniformly from [0, 1)</span> <span class="n">rn</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(),</span> <span class="n">maxval</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># If rn is less than 0.5, return lowres_img and highres_img unchanged</span> <span class="c1"># Otherwise, return both images flipped horizontally</span> <span class="k">return</span> <span class="n">ops</span><span class="o">.</span><span class="n">cond</span><span class="p">(</span> <span class="n">rn</span> <span
class="o">&lt;</span> <span class="mf">0.5</span><span class="p">,</span> <span class="k">lambda</span><span class="p">:</span> <span class="p">(</span><span class="n">lowres_img</span><span class="p">,</span> <span class="n">highres_img</span><span class="p">),</span> <span class="k">lambda</span><span class="p">:</span> <span class="p">(</span> <span class="c1"># axis=1 is the width axis of a (height, width, channels) image</span> <span class="n">ops</span><span class="o">.</span><span class="n">flip</span><span class="p">(</span><span class="n">lowres_img</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span> <span class="n">ops</span><span class="o">.</span><span class="n">flip</span><span class="p">(</span><span class="n">highres_img</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span> <span class="p">),</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">random_rotate</span><span class="p">(</span><span class="n">lowres_img</span><span class="p">,</span> <span class="n">highres_img</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Rotates both images by the same random multiple of 90 degrees."""</span> <span class="c1"># Draw a random integer in {0, 1, 2, 3} by truncating a uniform sample from [0, 4)</span> <span class="n">rn</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(),</span> <span class="n">maxval</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int32"</span> <span class="p">)</span> <span class="c1"># rn is the number of 90-degree counter-clockwise rotations applied to both images</span> <span class="k">return</span> <span class="n">tf</span><span class="o">.</span><span
class="n">image</span><span class="o">.</span><span class="n">rot90</span><span class="p">(</span><span class="n">lowres_img</span><span class="p">,</span> <span class="n">rn</span><span class="p">),</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">rot90</span><span class="p">(</span><span class="n">highres_img</span><span class="p">,</span> <span class="n">rn</span><span class="p">)</span> <span class="k">def</span> <span class="nf">random_crop</span><span class="p">(</span><span class="n">lowres_img</span><span class="p">,</span> <span class="n">highres_img</span><span class="p">,</span> <span class="n">hr_crop_size</span><span class="o">=</span><span class="mi">96</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="mi">4</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Crop images.</span> <span class="sd"> low resolution images: 24x24</span> <span class="sd"> high resolution images: 96x96</span> <span class="sd"> """</span> <span class="n">lowres_crop_size</span> <span class="o">=</span> <span class="n">hr_crop_size</span> <span class="o">//</span> <span class="n">scale</span> <span class="c1"># 96//4=24</span> <span class="n">lowres_img_shape</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">lowres_img</span><span class="p">)[:</span><span class="mi">2</span><span class="p">]</span> <span class="c1"># (height,width)</span> <span class="n">lowres_width</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(),</span> 
<span class="n">maxval</span><span class="o">=</span><span class="n">lowres_img_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">lowres_crop_size</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int32"</span><span class="p">,</span> <span class="p">)</span> <span class="n">lowres_height</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(),</span> <span class="n">maxval</span><span class="o">=</span><span class="n">lowres_img_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="n">lowres_crop_size</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int32"</span><span class="p">,</span> <span class="p">)</span> <span class="n">highres_width</span> <span class="o">=</span> <span class="n">lowres_width</span> <span class="o">*</span> <span class="n">scale</span> <span class="n">highres_height</span> <span class="o">=</span> <span class="n">lowres_height</span> <span class="o">*</span> <span class="n">scale</span> <span class="n">lowres_img_cropped</span> <span class="o">=</span> <span class="n">lowres_img</span><span class="p">[</span> <span class="n">lowres_height</span> <span class="p">:</span> 
<span class="n">lowres_height</span> <span class="o">+</span> <span class="n">lowres_crop_size</span><span class="p">,</span> <span class="n">lowres_width</span> <span class="p">:</span> <span class="n">lowres_width</span> <span class="o">+</span> <span class="n">lowres_crop_size</span><span class="p">,</span> <span class="p">]</span> <span class="c1"># 24x24</span> <span class="n">highres_img_cropped</span> <span class="o">=</span> <span class="n">highres_img</span><span class="p">[</span> <span class="n">highres_height</span> <span class="p">:</span> <span class="n">highres_height</span> <span class="o">+</span> <span class="n">hr_crop_size</span><span class="p">,</span> <span class="n">highres_width</span> <span class="p">:</span> <span class="n">highres_width</span> <span class="o">+</span> <span class="n">hr_crop_size</span><span class="p">,</span> <span class="p">]</span> <span class="c1"># 96x96</span> <span class="k">return</span> <span class="n">lowres_img_cropped</span><span class="p">,</span> <span class="n">highres_img_cropped</span> </code></pre></div> <hr /> <h2 id="tfdatadataset">Prepare a <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><code>tf.data.Dataset</code></a> object</h2> <p>We augment the training data with random horizontal flips and random 90-degree rotations.</p> <p>We use 24x24 RGB patches as low-resolution inputs, paired with the corresponding 96x96 high-resolution patches.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">dataset_object</span><span class="p">(</span><span class="n">dataset_cache</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span> <span class="n">ds</span> <span class="o">=</span> <span class="n">dataset_cache</span> <span class="n">ds</span> <span class="o">=</span> <span class="n">ds</span><span class="o">.</span><span class="n">map</span><span class="p">(</span> <span class="k">lambda</span> <span class="n">lowres</span><span
class="p">,</span> <span class="n">highres</span><span class="p">:</span> <span class="n">random_crop</span><span class="p">(</span><span class="n">lowres</span><span class="p">,</span> <span class="n">highres</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="mi">4</span><span class="p">),</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTOTUNE</span><span class="p">,</span> <span class="p">)</span> <span class="k">if</span> <span class="n">training</span><span class="p">:</span> <span class="n">ds</span> <span class="o">=</span> <span class="n">ds</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">random_rotate</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTOTUNE</span><span class="p">)</span> <span class="n">ds</span> <span class="o">=</span> <span class="n">ds</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">flip_left_right</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTOTUNE</span><span class="p">)</span> <span class="c1"># Batch the data into groups of 16 image pairs</span> <span class="n">ds</span> <span class="o">=</span> <span class="n">ds</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="mi">16</span><span class="p">)</span> <span class="k">if</span> <span class="n">training</span><span class="p">:</span> <span class="c1"># Repeat the data so that the dataset's cardinality becomes infinite</span> <span class="n">ds</span> <span class="o">=</span> <span class="n">ds</span><span class="o">.</span><span class="n">repeat</span><span class="p">()</span> <span class="c1"># Prefetching lets later batches be prepared while the current batch is being processed</span> <span class="n">ds</span> <span class="o">=</span> <span
class="n">ds</span><span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">buffer_size</span><span class="o">=</span><span class="n">AUTOTUNE</span><span class="p">)</span> <span class="k">return</span> <span class="n">ds</span> <span class="n">train_ds</span> <span class="o">=</span> <span class="n">dataset_object</span><span class="p">(</span><span class="n">train_cache</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">val_ds</span> <span class="o">=</span> <span class="n">dataset_object</span><span class="p">(</span><span class="n">val_cache</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="visualize-the-data">Visualize the data</h2> <p>Let's visualize a few sample images:</p> <div class="codehilite"><pre><span></span><code><span class="n">lowres</span><span class="p">,</span> <span class="n">highres</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">train_ds</span><span class="p">))</span> <span class="c1"># High Resolution Images</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">9</span><span class="p">):</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> 
<span class="mi">3</span><span class="p">,</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">highres</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"uint8"</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="n">highres</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="c1"># Low Resolution Images</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">9</span><span class="p">):</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span 
class="n">lowres</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"uint8"</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="n">lowres</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="k">def</span> <span class="nf">PSNR</span><span class="p">(</span><span class="n">super_resolution</span><span class="p">,</span> <span class="n">high_resolution</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Compute the peak signal-to-noise ratio (PSNR), a measure of image quality."""</span> <span class="c1"># Pixel values range from 0 to 255; tf.image.psnr returns one value per image in the batch</span> <span class="n">psnr_value</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">psnr</span><span class="p">(</span><span class="n">high_resolution</span><span class="p">,</span> <span class="n">super_resolution</span><span class="p">,</span> <span class="n">max_val</span><span class="o">=</span><span class="mi">255</span><span class="p">)</span> <span class="k">return</span> <span class="n">psnr_value</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/edsr/edsr_11_0.png" /></p> <p><img alt="png" src="/img/examples/vision/edsr/edsr_11_1.png" /></p> <hr /> <h2 id="build-the-model">Build the model</h2> <p>In the paper, the authors train three models: EDSR, MDSR, and a baseline model. 
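</p> <p>To get a feel for the size of the baseline model, its parameter count can be worked out by hand. The following is a rough sketch that mirrors the <code>make_model()</code> function defined later in this example, called with <code>num_filters=64</code> and <code>num_of_residual_blocks=16</code>. The helper names <code>conv2d_params</code> and <code>count_baseline_params</code> are hypothetical. The count assumes 3x3 kernels with biases for every convolution, 64-channel residual and upsampling blocks (<code>ResBlock</code> and <code>Upsampling</code> hard-code 64 filters), and 4x upsampling performed as two 2x <code>depth_to_space</code> steps, which have no weights of their own.</p>

```python
# Rough parameter count for the baseline model built by make_model().
# conv2d_params and count_baseline_params are hypothetical helpers for this sketch.


def conv2d_params(kernel_size, in_channels, out_channels):
    """Weights plus biases of a Conv2D layer with a square kernel."""
    return kernel_size * kernel_size * in_channels * out_channels + out_channels


def count_baseline_params(num_filters=64, num_of_residual_blocks=16, factor=2, num_upsampling_steps=2):
    """Count trainable parameters layer by layer.

    Assumes 3x3 kernels everywhere and that the residual and upsampling blocks
    work on num_filters channels (ResBlock and Upsampling hard-code 64).
    """
    k = 3
    # Head: RGB input -> num_filters feature maps.
    total = conv2d_params(k, 3, num_filters)
    # Residual blocks: two 3x3 convolutions each, no batch normalization.
    total += num_of_residual_blocks * 2 * conv2d_params(k, num_filters, num_filters)
    # Convolution applied after the stack of residual blocks.
    total += conv2d_params(k, num_filters, num_filters)
    # Each upsampling step: a convolution to num_filters * factor**2 channels,
    # followed by a weight-free depth_to_space back to num_filters channels.
    total += num_upsampling_steps * conv2d_params(k, num_filters, num_filters * factor**2)
    # Tail: num_filters feature maps -> RGB output (Rescaling layers have no weights).
    total += conv2d_params(k, num_filters, 3)
    return total


residual_share = 16 * 2 * conv2d_params(3, 64, 64)
print(count_baseline_params())  # 1517571
print(residual_share)  # 1181696
```

<p>Under these assumptions the baseline has 1,517,571 trainable parameters (about 1.5 million), and roughly 78% of them (1,181,696) sit in the 32 convolutions of the 16 residual blocks. Treat this as a back-of-the-envelope figure to compare against <code>model.summary()</code>, not as output copied from it.</p> <p>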
In this code example, we only train the baseline model.</p> <h3 id="comparison-with-model-with-three-residual-blocks">Comparison with model with three residual blocks</h3> <p>The residual block design of EDSR differs from that of ResNet: the batch normalization layers are removed, along with the final ReLU activation that ResNet applies after the addition. Because batch normalization normalizes the features, it reduces the flexibility of the network's output value range, so it is better to remove it. Removing it also reduces the GPU memory the model requires, since each batch normalization layer consumes as much memory as the convolutional layer that precedes it.</p> <p><img src="https://miro.medium.com/max/1050/1*EPviXGqlGWotVtV2gqVvNg.png" width="500" /></p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">EDSRModel</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span> <span class="c1"># Unpack the data. 
Its structure depends on your model and</span> <span class="c1"># on what you pass to `fit()`.</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">data</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="n">y_pred</span> <span class="o">=</span> <span class="bp">self</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="c1"># Forward pass</span> <span class="c1"># Compute the loss value</span> <span class="c1"># (the loss function is configured in `compile()`)</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">compiled_loss</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">,</span> <span class="n">regularization_losses</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">losses</span><span class="p">)</span> <span class="c1"># Compute gradients</span> <span class="n">trainable_vars</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainable_variables</span> <span class="n">gradients</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">)</span> <span class="c1"># Update weights</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">apply_gradients</span><span 
class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">))</span> <span class="c1"># Update metrics (includes the metric that tracks the loss)</span> <span class="bp">self</span><span class="o">.</span><span class="n">compiled_metrics</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">)</span> <span class="c1"># Return a dict mapping metric names to their current values</span> <span class="k">return</span> <span class="p">{</span><span class="n">m</span><span class="o">.</span><span class="n">name</span><span class="p">:</span> <span class="n">m</span><span class="o">.</span><span class="n">result</span><span class="p">()</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics</span><span class="p">}</span> <span class="k">def</span> <span class="nf">predict_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="c1"># Add a batch dimension with tf.expand_dims and convert to float32 with ops.cast</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span><span class="p">)</span> <span class="c1"># Pass the low-resolution image through the model</span> <span class="n">super_resolution_img</span> 
<span class="o">=</span> <span class="bp">self</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="c1"># Clip values to the valid pixel range [0, 255]</span> <span class="n">super_resolution_img</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">super_resolution_img</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">255</span><span class="p">)</span> <span class="c1"># Round values to the nearest integer</span> <span class="n">super_resolution_img</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">round</span><span class="p">(</span><span class="n">super_resolution_img</span><span class="p">)</span> <span class="c1"># Convert to uint8 and remove the batch dimension</span> <span class="n">super_resolution_img</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span> <span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">super_resolution_img</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"uint8"</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span> <span class="p">)</span> <span class="k">return</span> <span class="n">super_resolution_img</span> <span class="c1"># Residual Block</span> <span class="k">def</span> <span class="nf">ResBlock</span><span class="p">(</span><span class="n">inputs</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span
class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Add</span><span class="p">()([</span><span class="n">inputs</span><span class="p">,</span> <span class="n">x</span><span class="p">])</span> <span class="k">return</span> <span class="n">x</span> <span class="c1"># Upsampling Block</span> <span class="k">def</span> <span class="nf">Upsampling</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span> <span class="o">*</span> <span class="p">(</span><span class="n">factor</span><span class="o">**</span><span class="mi">2</span><span class="p">),</span> <span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span 
class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Lambda</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">depth_to_space</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">block_size</span><span class="o">=</span><span class="n">factor</span><span class="p">))(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span> <span class="o">*</span> <span class="p">(</span><span class="n">factor</span><span class="o">**</span><span class="mi">2</span><span class="p">),</span> <span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Lambda</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">depth_to_space</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">block_size</span><span class="o">=</span><span class="n">factor</span><span class="p">))(</span><span class="n">x</span><span class="p">)</span> <span 
class="k">return</span> <span class="n">x</span> <span class="k">def</span> <span class="nf">make_model</span><span class="p">(</span><span class="n">num_filters</span><span class="p">,</span> <span class="n">num_of_residual_blocks</span><span class="p">):</span> <span class="c1"># Accept inputs of any height and width</span> <span class="n">input_layer</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="c1"># Scale pixel values from [0, 255] to [0, 1]</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Rescaling</span><span class="p">(</span><span class="n">scale</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">/</span> <span class="mi">255</span><span class="p">)(</span><span class="n">input_layer</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x_new</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">num_filters</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Stack of residual blocks (16 in this example)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_of_residual_blocks</span><span class="p">):</span> <span class="n">x_new</span> <span class="o">=</span> <span class="n">ResBlock</span><span class="p">(</span><span 
class="n">x_new</span><span class="p">)</span> <span class="n">x_new</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">num_filters</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">)(</span><span class="n">x_new</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Add</span><span class="p">()([</span><span class="n">x</span><span class="p">,</span> <span class="n">x_new</span><span class="p">])</span> <span class="n">x</span> <span class="o">=</span> <span class="n">Upsampling</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">output_layer</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Rescaling</span><span class="p">(</span><span class="n">scale</span><span class="o">=</span><span class="mi">255</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">EDSRModel</span><span class="p">(</span><span class="n">input_layer</span><span class="p">,</span> <span class="n">output_layer</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">make_model</span><span class="p">(</span><span class="n">num_filters</span><span 
class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">num_of_residual_blocks</span><span class="o">=</span><span class="mi">16</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="train-the-model">Train the model</h2> <div class="codehilite"><pre><span></span><code><span class="c1"># Adam optimizer: learning rate 1e-4 for the first 5000 steps, then 5e-5</span> <span class="n">optim_edsr</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span> <span class="n">learning_rate</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">schedules</span><span class="o">.</span><span class="n">PiecewiseConstantDecay</span><span class="p">(</span> <span class="n">boundaries</span><span class="o">=</span><span class="p">[</span><span class="mi">5000</span><span class="p">],</span> <span class="n">values</span><span class="o">=</span><span class="p">[</span><span class="mf">1e-4</span><span class="p">,</span> <span class="mf">5e-5</span><span class="p">]</span> <span class="p">)</span> <span class="p">)</span> <span class="c1"># Compile with mean absolute error (L1 loss) as the loss and PSNR as the metric</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="n">optim_edsr</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"mae"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="n">PSNR</span><span class="p">])</span> <span class="c1"># Training for more epochs should improve results</span> <span class="n">model</span><span 
class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_ds</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">steps_per_epoch</span><span class="o">=</span><span class="mi">200</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">val_ds</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/100
200/200 ━━━━━━━━━━━━━━━━━━━━ 117s 472ms/step - psnr: 8.7874 - loss: 85.1546 - val_loss: 17.4624 - val_psnr: 8.7008
Epoch 10/100
200/200 ━━━━━━━━━━━━━━━━━━━━ 58s 288ms/step - psnr: 8.9519 - loss: 94.4611 - val_loss: 8.6002 - val_psnr: 6.4303
Epoch 20/100
200/200 ━━━━━━━━━━━━━━━━━━━━ 52s 261ms/step - psnr: 8.5120 - loss: 95.5767 - val_loss: 8.7330 - val_psnr: 6.3106
Epoch 30/100
200/200 ━━━━━━━━━━━━━━━━━━━━ 53s 262ms/step - psnr: 8.6051 - loss: 96.1541 - val_loss: 7.5442 - val_psnr: 7.9715
Epoch 40/100
200/200 ━━━━━━━━━━━━━━━━━━━━ 53s 263ms/step - psnr: 8.7405 - loss: 96.8159 - val_loss: 7.2734 - val_psnr: 7.6312
Epoch 50/100
200/200 ━━━━━━━━━━━━━━━━━━━━ 52s 259ms/step - psnr: 8.7648 - loss: 95.7817 - val_loss: 8.1772 - val_psnr: 7.1330
Epoch 60/100
200/200 ━━━━━━━━━━━━━━━━━━━━ 53s 264ms/step - psnr: 8.8651 - loss: 95.4793 - val_loss: 7.6550 - val_psnr: 7.2298
Epoch 70/100
200/200 ━━━━━━━━━━━━━━━━━━━━ 53s 263ms/step - psnr: 8.8489 - loss: 94.5993 - val_loss: 7.4607 - val_psnr: 6.6841
Epoch 80/100
200/200 ━━━━━━━━━━━━━━━━━━━━ 53s 263ms/step - psnr: 8.3046 - loss: 97.3796 - val_loss: 8.1050 - val_psnr: 8.0714
Epoch 90/100
200/200 ━━━━━━━━━━━━━━━━━━━━ 53s 264ms/step - psnr: 7.9295 - loss: 96.0314 - val_loss: 7.1515 - val_psnr: 6.8712
Epoch 100/100
200/200 ━━━━━━━━━━━━━━━━━━━━ 53s 263ms/step - psnr: 8.1666 - loss: 94.9792 - val_loss: 6.6524 - val_psnr: 6.5423

&lt;keras.src.callbacks.history.History at 
0x7fc1e8dd6890&gt; </code></pre></div> </div> <hr /> <h2 id="run-inference-on-new-images-and-plot-the-results">Run inference on new images and plot the results</h2> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">plot_results</span><span class="p">(</span><span class="n">lowres</span><span class="p">,</span> <span class="n">preds</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""</span> <span class="sd"> Display the low-resolution input next to the super-resolved prediction.</span> <span class="sd"> """</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">24</span><span class="p">,</span> <span class="mi">14</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">132</span><span class="p">),</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">lowres</span><span class="p">),</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">"Low resolution"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">133</span><span class="p">),</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">preds</span><span class="p">),</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">"Prediction"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="k">for</span> <span class="n">lowres</span><span class="p">,</span> 
<span class="n">highres</span> <span class="ow">in</span> <span class="n">val</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="mi">10</span><span class="p">):</span> <span class="n">lowres</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">random_crop</span><span class="p">(</span><span class="n">lowres</span><span class="p">,</span> <span class="p">(</span><span class="mi">150</span><span class="p">,</span> <span class="mi">150</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="n">preds</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict_step</span><span class="p">(</span><span class="n">lowres</span><span class="p">)</span> <span class="n">plot_results</span><span class="p">(</span><span class="n">lowres</span><span class="p">,</span> <span class="n">preds</span><span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/edsr/edsr_17_0.png" /></p> <p><img alt="png" src="/img/examples/vision/edsr/edsr_17_1.png" /></p> <p><img alt="png" src="/img/examples/vision/edsr/edsr_17_2.png" /></p> <p><img alt="png" src="/img/examples/vision/edsr/edsr_17_3.png" /></p> <hr /> <h2 id="final-remarks">Final remarks</h2> <p>In this example, we implemented the EDSR model (Enhanced Deep Residual Networks for Single Image Super-Resolution). 
You could improve the model's accuracy by training it for more epochs, and by training it on a wider variety of inputs degraded with mixed downscaling factors, so that it can handle a broader range of real-world images.</p> <p>You could also improve on this baseline EDSR model by implementing EDSR+, MDSR (multi-scale deep super-resolution), or MDSR+, which were proposed in the same paper.</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#enhanced-deep-residual-networks-for-singleimage-superresolution'>Enhanced Deep Residual Networks for single-image super-resolution</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#imports'>Imports</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#download-the-training-dataset'>Download the training dataset</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#flip-crop-and-resize-images'>Flip, crop and resize images</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#prepare-a-tfdatadataset-object'>Prepare a <code>tf.data.Dataset</code> object</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#visualize-the-data'>Visualize the data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#build-the-model'>Build the model</a> </div> <div class='k-outline-depth-3'> <a href='#comparison-with-model-with-three-residual-blocks'>Comparison with model with three residual blocks</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-the-model'>Train the model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#run-inference-on-new-images-and-plot-the-results'>Run inference on new images and plot the results</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#final-remarks'>Final remarks</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a 
href="https://policies.google.com/privacy">Privacy</a> </footer> </html>