Learning to Resize in Computer Vision
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/learnable_resizer/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Learning to Resize in Computer Vision"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Learning to Resize in Computer Vision"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Learning to Resize in Computer Vision</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); 
</script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a class="nav-sublink2" href="/examples/vision/mobilevit/">A mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" 
href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using Composable Fully-Convolutional Networks</a> <a class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a 
class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved Robustness</a> <a class="nav-sublink2" href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" 
href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/bit/">Image Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" 
href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2 active" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a class="nav-sublink2" href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a 
class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) 
+ 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + encodeURIComponent(query); return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a 
href='/examples/vision/'>Computer Vision</a> / Learning to Resize in Computer Vision </div> <div class='k-content'> <h1 id="learning-to-resize-in-computer-vision">Learning to Resize in Computer Vision</h1> <p><strong>Author:</strong> <a href="https://twitter.com/RisingSayak">Sayak Paul</a><br> <strong>Date created:</strong> 2021/04/30<br> <strong>Last modified:</strong> 2023/12/18<br> <strong>Description:</strong> How to optimally learn representations of images for a given resolution.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/learnable_resizer.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/learnable_resizer.py"><strong>GitHub source</strong></a></p> <p>It is a common belief that if we constrain vision models to perceive things as humans do, their performance can be improved. For example, in <a href="https://arxiv.org/abs/1811.12231">this work</a>, Geirhos et al. showed that vision models pre-trained on the ImageNet-1k dataset are biased towards texture, whereas human beings rely mostly on shape to recognize objects. But does this belief always apply, especially when it comes to improving the performance of vision models?</p> <p>It turns out that this may not always be the case. When training vision models, it is common to resize images to a lower resolution ((224 x 224), (299 x 299), etc.) to allow mini-batch learning and to stay within compute limits. 
We generally use image resizing methods like <strong>bilinear interpolation</strong> for this step, and the resized images do not lose much of their perceptual character to the human eye. In <a href="https://arxiv.org/abs/2103.09950v1">Learning to Resize Images for Computer Vision Tasks</a>, Talebi et al. show that if we optimize the perceptual quality of the images for the vision models rather than for the human eye, the models' performance can be improved further. They investigate the following question:</p> <p><strong>For a given image resolution and a model, how can the given images best be resized?</strong></p> <p>As shown in the paper, this idea consistently improves the performance of common vision models (pre-trained on ImageNet-1k) like DenseNet-121, ResNet-50, MobileNetV2, and EfficientNets. In this example, we will implement the learnable image resizing module as proposed in the paper and demonstrate it on the <a href="https://www.microsoft.com/en-us/download/details.aspx?id=54765">Cats and Dogs dataset</a> using the <a href="https://arxiv.org/abs/1608.06993">DenseNet-121</a> architecture.</p> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"tensorflow"</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">ops</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> <span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span> <span class="kn">import</span> <span class="nn">tensorflow_datasets</span> 
<span class="k">as</span> <span class="nn">tfds</span> <span class="n">tfds</span><span class="o">.</span><span class="n">disable_progress_bar</span><span class="p">()</span> <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> </code></pre></div> <hr /> <h2 id="define-hyperparameters">Define hyperparameters</h2> <p>To facilitate mini-batch learning, all images inside a given batch must have the same shape. This is why an initial resizing is required. We first resize all the images to a (300 x 300) shape and then learn their optimal representation at the (150 x 150) resolution.</p> <div class="codehilite"><pre><span></span><code><span class="n">INP_SIZE</span> <span class="o">=</span> <span class="p">(</span><span class="mi">300</span><span class="p">,</span> <span class="mi">300</span><span class="p">)</span> <span class="n">TARGET_SIZE</span> <span class="o">=</span> <span class="p">(</span><span class="mi">150</span><span class="p">,</span> <span class="mi">150</span><span class="p">)</span> <span class="n">INTERPOLATION</span> <span class="o">=</span> <span class="s2">"bilinear"</span> <span class="n">AUTO</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span> <span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">64</span> <span class="n">EPOCHS</span> <span class="o">=</span> <span class="mi">5</span> </code></pre></div> <p>In this example, we will use bilinear interpolation, but the learnable image resizer module does not depend on any specific interpolation method. 
We can also use others, such as bicubic.</p> <hr /> <h2 id="load-and-prepare-the-dataset">Load and prepare the dataset</h2> <p>For this example, we will only use 40% of the total training dataset.</p> <div class="codehilite"><pre><span></span><code><span class="n">train_ds</span><span class="p">,</span> <span class="n">validation_ds</span> <span class="o">=</span> <span class="n">tfds</span><span class="o">.</span><span class="n">load</span><span class="p">(</span> <span class="s2">"cats_vs_dogs"</span><span class="p">,</span> <span class="c1"># Reserve 10% for validation</span> <span class="n">split</span><span class="o">=</span><span class="p">[</span><span class="s2">"train[:40%]"</span><span class="p">,</span> <span class="s2">"train[40%:50%]"</span><span class="p">],</span> <span class="n">as_supervised</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">preprocess_dataset</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">label</span><span class="p">):</span> <span class="n">image</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="p">(</span><span class="n">INP_SIZE</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">INP_SIZE</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span> <span class="n">label</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">label</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> <span class="k">return</span> <span 
class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">label</span><span class="p">)</span> <span class="n">train_ds</span> <span class="o">=</span> <span class="p">(</span> <span class="n">train_ds</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">BATCH_SIZE</span> <span class="o">*</span> <span class="mi">100</span><span class="p">)</span> <span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">preprocess_dataset</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTO</span><span class="p">)</span> <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span> <span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">AUTO</span><span class="p">)</span> <span class="p">)</span> <span class="n">validation_ds</span> <span class="o">=</span> <span class="p">(</span> <span class="n">validation_ds</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">preprocess_dataset</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTO</span><span class="p">)</span> <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span> <span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">AUTO</span><span class="p">)</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="define-the-learnable-resizer-utilities">Define the learnable resizer utilities</h2> <p>The figure below (courtesy: <a href="https://arxiv.org/abs/2103.09950v1">Learning to Resize Images for Computer Vision Tasks</a>) presents the structure of the learnable resizing module:</p> <p><img alt="" 
src="https://i.ibb.co/gJYtSs0/image.png" /></p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">conv_block</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">filters</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">,</span> <span class="n">strides</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">layers</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="mf">0.2</span><span class="p">)):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">filters</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">,</span> <span class="n">strides</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="k">if</span> <span class="n">activation</span><span class="p">:</span> <span class="n">x</span> <span class="o">=</span> <span class="n">activation</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> <span class="k">def</span> <span class="nf">res_block</span><span class="p">(</span><span class="n">x</span><span class="p">):</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">x</span> <span class="n">x</span> <span class="o">=</span> <span 
class="n">conv_block</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">conv_block</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">16</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="k">return</span> <span class="n">layers</span><span class="o">.</span><span class="n">Add</span><span class="p">()([</span><span class="n">inputs</span><span class="p">,</span> <span class="n">x</span><span class="p">])</span> <span class="c1"># Note: user can change num_res_blocks to >1 also if needed</span> <span class="k">def</span> <span class="nf">get_learnable_resizer</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">num_res_blocks</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">interpolation</span><span class="o">=</span><span class="n">INTERPOLATION</span><span class="p">):</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="c1"># First, perform naive resizing.</span> <span class="n">naive_resize</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span 
class="n">Resizing</span><span class="p">(</span><span class="o">*</span><span class="n">TARGET_SIZE</span><span class="p">,</span> <span class="n">interpolation</span><span class="o">=</span><span class="n">interpolation</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="c1"># First convolution block without batch normalization.</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="n">filters</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">7</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="mf">0.2</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Second convolution block with batch normalization.</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="n">filters</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> 
<span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LeakyReLU</span><span class="p">(</span><span class="mf">0.2</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Intermediate resizing as a bottleneck.</span> <span class="n">bottleneck</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Resizing</span><span class="p">(</span><span class="o">*</span><span class="n">TARGET_SIZE</span><span class="p">,</span> <span class="n">interpolation</span><span class="o">=</span><span class="n">interpolation</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Residual passes.</span> <span class="c1"># First res_block will get bottleneck output as input</span> <span class="n">x</span> <span class="o">=</span> <span class="n">res_block</span><span class="p">(</span><span class="n">bottleneck</span><span class="p">)</span> <span class="c1"># Remaining res_blocks will get previous res_block output as input</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_res_blocks</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">res_block</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Projection.</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span> <span class="n">filters</span><span class="o">=</span><span 
class="n">filters</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span> <span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Skip connection.</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Add</span><span class="p">()([</span><span class="n">bottleneck</span><span class="p">,</span> <span class="n">x</span><span class="p">])</span> <span class="c1"># Final resized image.</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">7</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">final_resize</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Add</span><span class="p">()([</span><span class="n">naive_resize</span><span class="p">,</span> <span class="n">x</span><span 
class="p">])</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">final_resize</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"learnable_resizer"</span><span class="p">)</span> <span class="n">learnable_resizer</span> <span class="o">=</span> <span class="n">get_learnable_resizer</span><span class="p">()</span> </code></pre></div> <hr /> <h2 id="visualize-the-outputs-of-the-learnable-resizing-module">Visualize the outputs of the learnable resizing module</h2> <p>Here, we visualize what the resized images look like after being passed through the resizer while its weights are still randomly initialized (untrained).</p> <div class="codehilite"><pre><span></span><code><span class="n">sample_images</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">train_ds</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">image</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">sample_images</span><span class="p">[:</span><span class="mi">6</span><span class="p">]):</span> <span class="n">image</span> <span class="o">=</span> <span class="n">image</span> <span class="o">/</span> <span class="mi">255</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span 
class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">"Input Image"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">squeeze</span><span class="p">())</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">2</span><span class="p">)</span> <span class="n">resized_image</span> <span class="o">=</span> <span class="n">learnable_resizer</span><span class="p">(</span><span class="n">image</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="o">...</span><span class="p">])</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">"Resized Image"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">resized_image</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span 
class="n">squeeze</span><span class="p">())</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Corrupt JPEG data: 65 extraneous bytes before marker 0xd9 WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). 
</code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/learnable_resizer/learnable_resizer_13_1.png" /></p> <hr /> <h2 id="model-building-utility">Model building utility</h2> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">get_model</span><span class="p">():</span> <span class="n">backbone</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">applications</span><span class="o">.</span><span class="n">DenseNet121</span><span class="p">(</span> <span class="n">weights</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">include_top</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">classes</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">input_shape</span><span class="o">=</span><span class="p">((</span><span class="n">TARGET_SIZE</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">TARGET_SIZE</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="mi">3</span><span class="p">)),</span> <span class="p">)</span> <span class="n">backbone</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">True</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">((</span><span class="n">INP_SIZE</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">INP_SIZE</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="mi">3</span><span class="p">))</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Rescaling</span><span class="p">(</span><span 
class="n">scale</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">/</span> <span class="mi">255</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">learnable_resizer</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">backbone</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> </code></pre></div> <p>The structure of the learnable image resizer module allows for flexible integrations with different vision models.</p> <hr /> <h2 id="compile-and-train-our-model-with-learnable-resizer">Compile and train our model with learnable resizer</h2> <div class="codehilite"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="n">get_model</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">CategoricalCrossentropy</span><span class="p">(</span><span class="n">label_smoothing</span><span class="o">=</span><span class="mf">0.1</span><span class="p">),</span> <span class="n">optimizer</span><span class="o">=</span><span class="s2">"sgd"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"accuracy"</span><span class="p">],</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span 
class="n">fit</span><span class="p">(</span><span class="n">train_ds</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">validation_ds</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">EPOCHS</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/5 146/146 ━━━━━━━━━━━━━━━━━━━━ 1790s 12s/step - accuracy: 0.5783 - loss: 0.6877 - val_accuracy: 0.4953 - val_loss: 0.7173 Epoch 2/5 146/146 ━━━━━━━━━━━━━━━━━━━━ 1738s 12s/step - accuracy: 0.6516 - loss: 0.6436 - val_accuracy: 0.6148 - val_loss: 0.6605 Epoch 3/5 146/146 ━━━━━━━━━━━━━━━━━━━━ 1730s 12s/step - accuracy: 0.6881 - loss: 0.6185 - val_accuracy: 0.5529 - val_loss: 0.8655 Epoch 4/5 146/146 ━━━━━━━━━━━━━━━━━━━━ 1725s 12s/step - accuracy: 0.6985 - loss: 0.5980 - val_accuracy: 0.6862 - val_loss: 0.6070 Epoch 5/5 146/146 ━━━━━━━━━━━━━━━━━━━━ 1722s 12s/step - accuracy: 0.7499 - loss: 0.5595 - val_accuracy: 0.6737 - val_loss: 0.6321 <keras.src.callbacks.history.History at 0x7f126c5440a0> </code></pre></div> </div> <hr /> <h2 id="visualize-the-outputs-of-the-trained-visualizer">Visualize the outputs of the trained visualizer</h2> <div class="codehilite"><pre><span></span><code><span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">image</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">sample_images</span><span class="p">[:</span><span class="mi">6</span><span class="p">]):</span> <span class="n">image</span> <span class="o">=</span> <span class="n">image</span> 
<span class="o">/</span> <span class="mi">255</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">"Input Image"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">squeeze</span><span class="p">())</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">2</span><span class="p">)</span> <span class="n">resized_image</span> <span class="o">=</span> <span class="n">learnable_resizer</span><span class="p">(</span><span class="n">image</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="o">...</span><span class="p">])</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">"Resized Image"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span 
class="n">imshow</span><span class="p">(</span><span class="n">resized_image</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span> <span class="o">/</span> <span class="mi">10</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/learnable_resizer/learnable_resizer_20_1.png" /></p> <p>The plot shows that the visuals of the images have improved with training. 
The following table compares the model that uses the learnable resizing module against a baseline that uses plain bilinear interpolation:</p> <table> <thead> <tr> <th style="text-align: center;">Model</th> <th style="text-align: center;">Number of parameters (millions)</th> <th style="text-align: center;">Top-1 accuracy</th> </tr> </thead> <tbody> <tr> <td style="text-align: center;">With the learnable resizer</td> <td style="text-align: center;">7.051717</td> <td style="text-align: center;">67.67%</td> </tr> <tr> <td style="text-align: center;">Without the learnable resizer</td> <td style="text-align: center;">7.039554</td> <td style="text-align: center;">60.19%</td> </tr> </tbody> </table> <p>For more details, you can check out <a href="https://github.com/sayakpaul/Learnable-Image-Resizing">this repository</a>. Note that, unlike this example, the models reported above were trained for 10 epochs on 90% of the Cats and Dogs training set. Also note that the resizing module adds a negligible number of parameters (about 0.012 million, roughly 0.17% of the baseline). To ensure that the improvement in performance is not due to stochasticity, both models were trained from the same initial random weights.</p> <p>A question worth asking here is: <em>isn't the improved accuracy simply a consequence of adding more layers (the resizer is a mini network after all) to the model, compared to the baseline?</em></p> <p>To show that this is not the case, the authors conduct the following experiment:</p> <ul> <li>Take a pre-trained model trained at some resolution, say (224 x 224).</li> </ul> <ul> <li>First, use it to run inference on images resized to a lower resolution. Record the performance.</li> </ul> <ul> <li>For the second experiment, plug in the resizer module at the top of the pre-trained model and warm-start the training. 
Record the performance.</li> </ul> <p>The authors argue that the second option is better because it helps the model learn to adjust its representations to the given resolution. Since the results are purely empirical, a few more experiments, such as analyzing the cross-channel interaction, would have strengthened the argument. It is worth noting that elements like <a href="https://arxiv.org/abs/1709.01507">Squeeze and Excitation (SE) blocks</a> and <a href="https://arxiv.org/abs/1904.11492">Global Context (GC) blocks</a> also add a few parameters to an existing network, but they are known to help a network process information in systematic ways to improve the overall performance.</p> <hr /> <h2 id="notes">Notes</h2> <ul> <li>To impose shape bias inside the vision models, Geirhos et al. trained them with a combination of natural and stylized images. It might be interesting to investigate whether this learnable resizing module could achieve something similar, as the outputs seem to discard the texture information.</li> </ul> <ul> <li>The resizer module can handle arbitrary resolutions and aspect ratios, which is very important for tasks like object detection and segmentation.</li> </ul> <ul> <li>There is another closely related topic, <strong><em>adaptive image resizing</em></strong>, that attempts to resize images/feature maps adaptively during training. 
<a href="https://arxiv.org/abs/2104.00298">EfficientNetV2</a> uses this idea.</li> </ul> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#learning-to-resize-in-computer-vision'>Learning to Resize in Computer Vision</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#define-hyperparameters'>Define hyperparameters</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#load-and-prepare-the-dataset'>Load and prepare the dataset</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#define-the-learnable-resizer-utilities'>Define the learnable resizer utilities</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#visualize-the-outputs-of-the-learnable-resizing-module'>Visualize the outputs of the learnable resizing module</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#model-building-utility'>Model building utility</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#compile-and-train-our-model-with-learnable-resizer'>Compile and train our model with learnable resizer</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#visualize-the-outputs-of-the-trained-visualizer'>Visualize the outputs of the trained visualizer</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#notes'>Notes</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>