CINXE.COM
Low-light image enhancement using MIRNet
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/mirnet/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Low-light image enhancement using MIRNet"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Low-light image enhancement using MIRNet"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Low-light image enhancement using MIRNet</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a class="nav-sublink2" href="/examples/vision/mobilevit/">A mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using Composable Fully-Convolutional Networks</a> <a class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2 active" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved Robustness</a> <a class="nav-sublink2" href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/bit/">Image Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a class="nav-sublink2" href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/vision/'>Computer Vision</a> / Low-light image enhancement using MIRNet </div> <div class='k-content'> <h1 id="lowlight-image-enhancement-using-mirnet">Low-light image enhancement using MIRNet</h1> <p><strong>Author:</strong> <a href="http://github.com/soumik12345">Soumik Rakshit</a><br> <strong>Date created:</strong> 2021/09/11<br> <strong>Last modified:</strong> 2023/07/15<br> <strong>Description:</strong> Implementing the MIRNet architecture for low-light image enhancement.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/mirnet.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/mirnet.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>With the goal of recovering high-quality image content from its degraded version, image restoration enjoys numerous applications, such as in photography, security, medical imaging, and remote sensing. In this example, we implement the <strong>MIRNet</strong> model for low-light image enhancement, a fully-convolutional architecture that learns an enriched set of features that combines contextual information from multiple scales, while simultaneously preserving the high-resolution spatial details.</p> <h3 id="references">References:</h3> <ul> <li><a href="https://arxiv.org/abs/2003.06792">Learning Enriched Features for Real Image Restoration and Enhancement</a></li> <li><a href="http://www.cnbc.cmu.edu/~tai/cp_papers/E.Land_Retinex_Theory_ScientifcAmerican.pdf">The Retinex Theory of Color Vision</a></li> <li><a href="https://ieeexplore.ieee.org/document/413553">Two deterministic half-quadratic regularization algorithms for computed imaging</a></li> </ul> <hr /> <h2 id="downloading-loldataset">Downloading LOLDataset</h2> <p>The <strong>LoL Dataset</strong> has been created for low-light image enhancement. It provides 485 images for training and 15 for testing. Each image pair in the dataset consists of a low-light input image and its corresponding well-exposed reference image.</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"tensorflow"</span> <span class="kn">import</span> <span class="nn">random</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">from</span> <span class="nn">glob</span> <span class="kn">import</span> <span class="n">glob</span> <span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span><span class="p">,</span> <span class="n">ImageOps</span> <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> <span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span> </code></pre></div> <div class="codehilite"><pre><span></span><code><span class="err">!</span><span class="n">wget</span> <span class="n">https</span><span class="p">:</span><span class="o">//</span><span class="n">huggingface</span><span class="o">.</span><span class="n">co</span><span class="o">/</span><span class="n">datasets</span><span class="o">/</span><span class="n">geekyrakshit</span><span class="o">/</span><span class="n">LoL</span><span class="o">-</span><span class="n">Dataset</span><span class="o">/</span><span class="n">resolve</span><span class="o">/</span><span class="n">main</span><span class="o">/</span><span class="n">lol_dataset</span><span class="o">.</span><span class="n">zip</span> <span class="err">!</span><span class="n">unzip</span> <span class="o">-</span><span class="n">q</span> <span class="n">lol_dataset</span><span class="o">.</span><span class="n">zip</span> <span class="o">&&</span> <span class="n">rm</span> <span class="n">lol_dataset</span><span class="o">.</span><span class="n">zip</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>--2023-11-10 23:10:00-- https://huggingface.co/datasets/geekyrakshit/LoL-Dataset/resolve/main/lol_dataset.zip Resolving huggingface.co (huggingface.co)... 3.163.189.74, 3.163.189.37, 3.163.189.114, ... Connecting to huggingface.co (huggingface.co)|3.163.189.74|:443... connected. HTTP request sent, awaiting response... 302 Found Location: https://cdn-lfs.huggingface.co/repos/d9/09/d909ef7668bb417b7065a311bd55a3084cc83a1f918e13cb41c5503328432db2/419fddc48958cd0f5599939ee0248852a37ceb8bb738c9b9525e95b25a89de9a?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27lol_dataset.zip%3B+filename%3D%22lol_dataset.zip%22%3B&response-content-type=application%2Fzip&Expires=1699917000&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTY5OTkxNzAwMH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy9kOS8wOS9kOTA5ZWY3NjY4YmI0MTdiNzA2NWEzMTFiZDU1YTMwODRjYzgzYTFmOTE4ZTEzY2I0MWM1NTAzMzI4NDMyZGIyLzQxOWZkZGM0ODk1OGNkMGY1NTk5OTM5ZWUwMjQ4ODUyYTM3Y2ViOGJiNzM4YzliOTUyNWU5NWIyNWE4OWRlOWE%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qJnJlc3BvbnNlLWNvbnRlbnQtdHlwZT0qIn1dfQ__&Signature=xyZ1oUBOnWdy6-vCAFzqZsDMetsPu6OSluyOoTS%7EKRZ6lvAy8yUwQgp5WjcZGJ7Jnex0IdnsPiUzsxaxjM-eZjUcQGPdGj4WhSV5DUBxr8xkwTEospYSg1fX%7EE2I1KkP9gBsXvinsKIOAZzchbg9f28xxdlvTbZ0h4ndcUfbDPknwlU1CIZNa5qjU6NqLMH2bPQmI1AIVau2DgQC%7E1n2dgTZsMfHTVmoM2ivsAl%7E9XgQ3m247ke2aj5BmgssZF52VWKTE-vwYDtbuiem73pS6gS-dZlmXYPE1OSRr2tsDo1cgPEBBtuK3hEnYcOq8jjEZk3AEAbFAJoHKLVIERZ30g__&Key-Pair-Id=KVTP0A1DKRTAX [following] --2023-11-10 23:10:00-- https://cdn-lfs.huggingface.co/repos/d9/09/d909ef7668bb417b7065a311bd55a3084cc83a1f918e13cb41c5503328432db2/419fddc48958cd0f5599939ee0248852a37ceb8bb738c9b9525e95b25a89de9a?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27lol_dataset.zip%3B+filename%3D%22lol_dataset.zip%22%3B&response-content-type=application%2Fzip&Expires=1699917000&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTY5OTkxNzAwMH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy9kOS8wOS9kOTA5ZWY3NjY4YmI0MTdiNzA2NWEzMTFiZDU1YTMwODRjYzgzYTFmOTE4ZTEzY2I0MWM1NTAzMzI4NDMyZGIyLzQxOWZkZGM0ODk1OGNkMGY1NTk5OTM5ZWUwMjQ4ODUyYTM3Y2ViOGJiNzM4YzliOTUyNWU5NWIyNWE4OWRlOWE%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qJnJlc3BvbnNlLWNvbnRlbnQtdHlwZT0qIn1dfQ__&Signature=xyZ1oUBOnWdy6-vCAFzqZsDMetsPu6OSluyOoTS%7EKRZ6lvAy8yUwQgp5WjcZGJ7Jnex0IdnsPiUzsxaxjM-eZjUcQGPdGj4WhSV5DUBxr8xkwTEospYSg1fX%7EE2I1KkP9gBsXvinsKIOAZzchbg9f28xxdlvTbZ0h4ndcUfbDPknwlU1CIZNa5qjU6NqLMH2bPQmI1AIVau2DgQC%7E1n2dgTZsMfHTVmoM2ivsAl%7E9XgQ3m247ke2aj5BmgssZF52VWKTE-vwYDtbuiem73pS6gS-dZlmXYPE1OSRr2tsDo1cgPEBBtuK3hEnYcOq8jjEZk3AEAbFAJoHKLVIERZ30g__&Key-Pair-Id=KVTP0A1DKRTAX Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 108.138.94.122, 108.138.94.14, 108.138.94.25, ... Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|108.138.94.122|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 347171015 (331M) [application/zip] Saving to: ‘lol_dataset.zip’ </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>lol_dataset.zip 100%[===================>] 331.09M 316MB/s in 1.0s </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>2023-11-10 23:10:01 (316 MB/s) - ‘lol_dataset.zip’ saved [347171015/347171015] </code></pre></div> </div> <hr /> <h2 id="creating-a-tensorflow-dataset">Creating a TensorFlow Dataset</h2> <p>We use 300 image pairs from the LoL Dataset's training set for training, and we use the remaining 185 image pairs for validation. We generate random crops of size <code>128 x 128</code> from the image pairs to be used for both training and validation.</p> <div class="codehilite"><pre><span></span><code><span class="n">random</span><span class="o">.</span><span class="n">seed</span><span class="p">(</span><span class="mi">10</span><span class="p">)</span> <span class="n">IMAGE_SIZE</span> <span class="o">=</span> <span class="mi">128</span> <span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">4</span> <span class="n">MAX_TRAIN_IMAGES</span> <span class="o">=</span> <span class="mi">300</span> <span class="k">def</span> <span class="nf">read_image</span><span class="p">(</span><span class="n">image_path</span><span class="p">):</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">read_file</span><span class="p">(</span><span class="n">image_path</span><span class="p">)</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">decode_png</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">channels</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span> <span class="n">image</span><span class="o">.</span><span class="n">set_shape</span><span class="p">([</span><span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">/</span> <span class="mf">255.0</span> <span class="k">return</span> <span class="n">image</span> <span class="k">def</span> <span class="nf">random_crop</span><span class="p">(</span><span class="n">low_image</span><span class="p">,</span> <span class="n">enhanced_image</span><span class="p">):</span> <span class="n">low_image_shape</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">low_image</span><span class="p">)[:</span><span class="mi">2</span><span class="p">]</span> <span class="n">low_w</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(),</span> <span class="n">maxval</span><span class="o">=</span><span class="n">low_image_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="n">IMAGE_SIZE</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">int32</span> <span class="p">)</span> <span class="n">low_h</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(),</span> <span class="n">maxval</span><span class="o">=</span><span class="n">low_image_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="n">IMAGE_SIZE</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">int32</span> <span class="p">)</span> <span class="n">low_image_cropped</span> <span class="o">=</span> <span class="n">low_image</span><span class="p">[</span> <span class="n">low_h</span> <span class="p">:</span> <span class="n">low_h</span> <span class="o">+</span> <span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="n">low_w</span> <span class="p">:</span> <span class="n">low_w</span> <span class="o">+</span> <span class="n">IMAGE_SIZE</span> <span class="p">]</span> <span class="n">enhanced_image_cropped</span> <span class="o">=</span> <span class="n">enhanced_image</span><span class="p">[</span> <span class="n">low_h</span> <span class="p">:</span> <span class="n">low_h</span> <span class="o">+</span> <span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="n">low_w</span> <span class="p">:</span> <span class="n">low_w</span> <span class="o">+</span> <span class="n">IMAGE_SIZE</span> <span class="p">]</span> <span class="c1"># in order to avoid `NONE` during shape inference</span> <span class="n">low_image_cropped</span><span class="o">.</span><span class="n">set_shape</span><span class="p">([</span><span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="n">enhanced_image_cropped</span><span class="o">.</span><span class="n">set_shape</span><span class="p">([</span><span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="k">return</span> <span class="n">low_image_cropped</span><span class="p">,</span> <span class="n">enhanced_image_cropped</span> <span class="k">def</span> <span class="nf">load_data</span><span class="p">(</span><span class="n">low_light_image_path</span><span class="p">,</span> <span class="n">enhanced_image_path</span><span class="p">):</span> <span class="n">low_light_image</span> <span class="o">=</span> <span class="n">read_image</span><span class="p">(</span><span class="n">low_light_image_path</span><span class="p">)</span> <span class="n">enhanced_image</span> <span class="o">=</span> <span class="n">read_image</span><span class="p">(</span><span class="n">enhanced_image_path</span><span class="p">)</span> <span class="n">low_light_image</span><span class="p">,</span> <span class="n">enhanced_image</span> <span class="o">=</span> <span class="n">random_crop</span><span class="p">(</span><span class="n">low_light_image</span><span class="p">,</span> <span class="n">enhanced_image</span><span class="p">)</span> <span class="k">return</span> <span class="n">low_light_image</span><span class="p">,</span> <span class="n">enhanced_image</span> <span class="k">def</span> <span class="nf">get_dataset</span><span class="p">(</span><span class="n">low_light_images</span><span class="p">,</span> <span class="n">enhanced_images</span><span class="p">):</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">low_light_images</span><span class="p">,</span> <span class="n">enhanced_images</span><span class="p">))</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">load_data</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span><span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">drop_remainder</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="k">return</span> <span class="n">dataset</span> <span class="n">train_low_light_images</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">glob</span><span class="p">(</span><span class="s2">"./lol_dataset/our485/low/*"</span><span class="p">))[:</span><span class="n">MAX_TRAIN_IMAGES</span><span class="p">]</span> <span class="n">train_enhanced_images</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">glob</span><span class="p">(</span><span class="s2">"./lol_dataset/our485/high/*"</span><span class="p">))[:</span><span class="n">MAX_TRAIN_IMAGES</span><span class="p">]</span> <span class="n">val_low_light_images</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">glob</span><span class="p">(</span><span class="s2">"./lol_dataset/our485/low/*"</span><span class="p">))[</span><span class="n">MAX_TRAIN_IMAGES</span><span class="p">:]</span> <span class="n">val_enhanced_images</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">glob</span><span class="p">(</span><span class="s2">"./lol_dataset/our485/high/*"</span><span class="p">))[</span><span class="n">MAX_TRAIN_IMAGES</span><span class="p">:]</span> <span class="n">test_low_light_images</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">glob</span><span class="p">(</span><span class="s2">"./lol_dataset/eval15/low/*"</span><span class="p">))</span> <span class="n">test_enhanced_images</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">glob</span><span class="p">(</span><span class="s2">"./lol_dataset/eval15/high/*"</span><span class="p">))</span> <span class="n">train_dataset</span> <span class="o">=</span> <span class="n">get_dataset</span><span class="p">(</span><span class="n">train_low_light_images</span><span class="p">,</span> <span class="n">train_enhanced_images</span><span class="p">)</span> <span class="n">val_dataset</span> <span class="o">=</span> <span class="n">get_dataset</span><span class="p">(</span><span class="n">val_low_light_images</span><span class="p">,</span> <span class="n">val_enhanced_images</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Train Dataset:"</span><span class="p">,</span> <span class="n">train_dataset</span><span class="o">.</span><span class="n">element_spec</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Val Dataset:"</span><span class="p">,</span> <span class="n">val_dataset</span><span class="o">.</span><span class="n">element_spec</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Train Dataset: (TensorSpec(shape=(4, 128, 128, 3), dtype=tf.float32, name=None), TensorSpec(shape=(4, 128, 128, 3), dtype=tf.float32, name=None)) Val Dataset: (TensorSpec(shape=(4, 128, 128, 3), dtype=tf.float32, name=None), TensorSpec(shape=(4, 128, 128, 3), dtype=tf.float32, name=None)) </code></pre></div> </div> <hr /> <h2 id="mirnet-model">MIRNet Model</h2> <p>Here are the main features of the MIRNet model:</p> <ul> <li>A feature extraction model that computes a complementary set of features across multiple spatial scales, while maintaining the original high-resolution features to preserve precise spatial details.</li> <li>A regularly repeated mechanism for information exchange, where the features across multi-resolution branches are progressively fused together for improved representation learning.</li> <li>A new approach to fuse multi-scale features using a selective kernel network that dynamically combines variable receptive fields and faithfully preserves the original feature information at each spatial resolution.</li> <li>A recursive residual design that progressively breaks down the input signal in order to simplify the overall learning process, and allows the construction of very deep networks.</li> </ul> <p><img alt="" src="https://raw.githubusercontent.com/soumik12345/MIRNet/master/assets/mirnet_architecture.png" /></p> <h3 id="selective-kernel-feature-fusion">Selective Kernel Feature Fusion</h3> <p>The Selective Kernel Feature Fusion or SKFF module performs dynamic adjustment of receptive fields via two operations: <strong>Fuse</strong> and <strong>Select</strong>. The Fuse operator generates global feature descriptors by combining the information from multi-resolution streams. The Select operator uses these descriptors to recalibrate the feature maps (of different streams) followed by their aggregation.</p> <p><strong>Fuse</strong>: The SKFF receives inputs from three parallel convolution streams carrying different scales of information. We first combine these multi-scale features using an element-wise sum, on which we apply Global Average Pooling (GAP) across the spatial dimension. Next, we apply a channel- downscaling convolution layer to generate a compact feature representation which passes through three parallel channel-upscaling convolution layers (one for each resolution stream) and provides us with three feature descriptors.</p> <p><strong>Select</strong>: This operator applies the softmax function to the feature descriptors to obtain the corresponding activations that are used to adaptively recalibrate multi-scale feature maps. The aggregated features are defined as the sum of product of the corresponding multi-scale feature and the feature descriptor.</p> <p><img alt="" src="https://i.imgur.com/7U6ixF6.png" /></p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">selective_kernel_feature_fusion</span><span class="p">(</span> <span class="n">multi_scale_feature_1</span><span class="p">,</span> <span class="n">multi_scale_feature_2</span><span class="p">,</span> <span class="n">multi_scale_feature_3</span> <span class="p">):</span> <span class="n">channels</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">multi_scale_feature_1</span><span class="o">.</span><span class="n">shape</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">combined_feature</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Add</span><span class="p">()(</span> <span class="p">[</span><span class="n">multi_scale_feature_1</span><span class="p">,</span> <span class="n">multi_scale_feature_2</span><span class="p">,</span> <span class="n">multi_scale_feature_3</span><span class="p">]</span> <span class="p">)</span> <span class="n">gap</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">GlobalAveragePooling2D</span><span class="p">()(</span><span class="n">combined_feature</span><span class="p">)</span> <span class="n">channel_wise_statistics</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Reshape</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">channels</span><span class="p">))(</span><span class="n">gap</span><span class="p">)</span> <span class="n">compact_feature_representation</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span> <span class="n">filters</span><span class="o">=</span><span class="n">channels</span> <span class="o">//</span> <span class="mi">8</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span> <span class="p">)(</span><span class="n">channel_wise_statistics</span><span class="p">)</span> <span class="n">feature_descriptor_1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span> <span class="n">channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"softmax"</span> <span class="p">)(</span><span class="n">compact_feature_representation</span><span class="p">)</span> <span class="n">feature_descriptor_2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span> <span class="n">channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"softmax"</span> <span class="p">)(</span><span class="n">compact_feature_representation</span><span class="p">)</span> <span class="n">feature_descriptor_3</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span> <span class="n">channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"softmax"</span> <span class="p">)(</span><span class="n">compact_feature_representation</span><span class="p">)</span> <span class="n">feature_1</span> <span class="o">=</span> <span class="n">multi_scale_feature_1</span> <span class="o">*</span> <span class="n">feature_descriptor_1</span> <span class="n">feature_2</span> <span class="o">=</span> <span class="n">multi_scale_feature_2</span> <span class="o">*</span> <span class="n">feature_descriptor_2</span> <span class="n">feature_3</span> <span class="o">=</span> <span class="n">multi_scale_feature_3</span> <span class="o">*</span> <span class="n">feature_descriptor_3</span> <span class="n">aggregated_feature</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Add</span><span class="p">()([</span><span class="n">feature_1</span><span class="p">,</span> <span class="n">feature_2</span><span class="p">,</span> <span class="n">feature_3</span><span class="p">])</span> <span class="k">return</span> <span class="n">aggregated_feature</span> </code></pre></div> <h3 id="dual-attention-unit">Dual Attention Unit</h3> <p>The Dual Attention Unit or DAU is used to extract features in the convolutional streams. While the SKFF block fuses information across multi-resolution branches, we also need a mechanism to share information within a feature tensor, both along the spatial and the channel dimensions which is done by the DAU block. The DAU suppresses less useful features and only allows more informative ones to pass further. This feature recalibration is achieved by using <strong>Channel Attention</strong> and <strong>Spatial Attention</strong> mechanisms.</p> <p>The <strong>Channel Attention</strong> branch exploits the inter-channel relationships of the convolutional feature maps by applying squeeze and excitation operations. Given a feature map, the squeeze operation applies Global Average Pooling across spatial dimensions to encode global context, thus yielding a feature descriptor. The excitation operator passes this feature descriptor through two convolutional layers followed by the sigmoid gating and generates activations. Finally, the output of Channel Attention branch is obtained by rescaling the input feature map with the output activations.</p> <p>The <strong>Spatial Attention</strong> branch is designed to exploit the inter-spatial dependencies of convolutional features. The goal of Spatial Attention is to generate a spatial attention map and use it to recalibrate the incoming features. To generate the spatial attention map, the Spatial Attention branch first independently applies Global Average Pooling and Max Pooling operations on input features along the channel dimensions and concatenates the outputs to form a resultant feature map which is then passed through a convolution and sigmoid activation to obtain the spatial attention map. This spatial attention map is then used to rescale the input feature map.</p> <p><img alt="" src="https://i.imgur.com/Dl0IwQs.png" /></p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">ChannelPooling</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">axis</span> <span class="o">=</span> <span class="n">axis</span> <span class="bp">self</span><span class="o">.</span><span class="n">concat</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Concatenate</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">axis</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">average_pooling</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">),</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="n">max_pooling</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">reduce_max</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">),</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">concat</span><span class="p">([</span><span class="n">average_pooling</span><span class="p">,</span> <span class="n">max_pooling</span><span class="p">])</span> <span class="k">def</span> <span class="nf">get_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="n">config</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">get_config</span><span class="p">()</span> <span class="n">config</span><span class="o">.</span><span class="n">update</span><span class="p">({</span><span class="s2">"axis"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">axis</span><span class="p">})</span> <span class="k">def</span> <span class="nf">spatial_attention_block</span><span class="p">(</span><span class="n">input_tensor</span><span class="p">):</span> <span class="n">compressed_feature_map</span> <span class="o">=</span> <span class="n">ChannelPooling</span><span class="p">(</span><span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)(</span><span class="n">input_tensor</span><span class="p">)</span> <span class="n">feature_map</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))(</span><span class="n">compressed_feature_map</span><span class="p">)</span> <span class="n">feature_map</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">activations</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">feature_map</span><span class="p">)</span> <span class="k">return</span> <span class="n">input_tensor</span> <span class="o">*</span> <span class="n">feature_map</span> <span class="k">def</span> <span class="nf">channel_attention_block</span><span class="p">(</span><span class="n">input_tensor</span><span class="p">):</span> <span class="n">channels</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">input_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">average_pooling</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">GlobalAveragePooling2D</span><span class="p">()(</span><span class="n">input_tensor</span><span class="p">)</span> <span class="n">feature_descriptor</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Reshape</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">channels</span><span class="p">))(</span><span class="n">average_pooling</span><span class="p">)</span> <span class="n">feature_activations</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span> <span class="n">filters</span><span class="o">=</span><span class="n">channels</span> <span class="o">//</span> <span class="mi">8</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span> <span class="p">)(</span><span class="n">feature_descriptor</span><span class="p">)</span> <span class="n">feature_activations</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span> <span class="n">filters</span><span class="o">=</span><span class="n">channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"sigmoid"</span> <span class="p">)(</span><span class="n">feature_activations</span><span class="p">)</span> <span class="k">return</span> <span class="n">input_tensor</span> <span class="o">*</span> <span class="n">feature_activations</span> <span class="k">def</span> <span class="nf">dual_attention_unit_block</span><span class="p">(</span><span class="n">input_tensor</span><span class="p">):</span> <span class="n">channels</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">input_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">feature_map</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span> <span class="n">channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span> <span class="p">)(</span><span class="n">input_tensor</span><span class="p">)</span> <span class="n">feature_map</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">)(</span> <span class="n">feature_map</span> <span class="p">)</span> <span class="n">channel_attention</span> <span class="o">=</span> <span class="n">channel_attention_block</span><span class="p">(</span><span class="n">feature_map</span><span class="p">)</span> <span class="n">spatial_attention</span> <span class="o">=</span> <span class="n">spatial_attention_block</span><span class="p">(</span><span class="n">feature_map</span><span class="p">)</span> <span class="n">concatenation</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Concatenate</span><span class="p">(</span><span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)([</span><span class="n">channel_attention</span><span class="p">,</span> <span class="n">spatial_attention</span><span class="p">])</span> <span class="n">concatenation</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))(</span><span class="n">concatenation</span><span class="p">)</span> <span class="k">return</span> <span class="n">layers</span><span class="o">.</span><span class="n">Add</span><span class="p">()([</span><span class="n">input_tensor</span><span class="p">,</span> <span class="n">concatenation</span><span class="p">])</span> </code></pre></div> <h3 id="multiscale-residual-block">Multi-Scale Residual Block</h3> <p>The Multi-Scale Residual Block is capable of generating a spatially-precise output by maintaining high-resolution representations, while receiving rich contextual information from low-resolutions. The MRB consists of multiple (three in this paper) fully-convolutional streams connected in parallel. It allows information exchange across parallel streams in order to consolidate the high-resolution features with the help of low-resolution features, and vice versa. The MIRNet employs a recursive residual design (with skip connections) to ease the flow of information during the learning process. In order to maintain the residual nature of our architecture, residual resizing modules are used to perform downsampling and upsampling operations that are used in the Multi-scale Residual Block.</p> <p><img alt="" src="https://i.imgur.com/wzZKV57.png" /></p> <div class="codehilite"><pre><span></span><code><span class="c1"># Recursive Residual Modules</span> <span class="k">def</span> <span class="nf">down_sampling_module</span><span class="p">(</span><span class="n">input_tensor</span><span class="p">):</span> <span class="n">channels</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">input_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">main_branch</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span> <span class="n">input_tensor</span> <span class="p">)</span> <span class="n">main_branch</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span> <span class="n">channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span> <span class="p">)(</span><span class="n">main_branch</span><span class="p">)</span> <span class="n">main_branch</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MaxPooling2D</span><span class="p">()(</span><span class="n">main_branch</span><span class="p">)</span> <span class="n">main_branch</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">channels</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))(</span><span class="n">main_branch</span><span class="p">)</span> <span class="n">skip_branch</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MaxPooling2D</span><span class="p">()(</span><span class="n">input_tensor</span><span class="p">)</span> <span class="n">skip_branch</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">channels</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))(</span><span class="n">skip_branch</span><span class="p">)</span> <span class="k">return</span> <span class="n">layers</span><span class="o">.</span><span class="n">Add</span><span class="p">()([</span><span class="n">skip_branch</span><span class="p">,</span> <span class="n">main_branch</span><span class="p">])</span> <span class="k">def</span> <span class="nf">up_sampling_module</span><span class="p">(</span><span class="n">input_tensor</span><span class="p">):</span> <span class="n">channels</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">input_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">main_branch</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span> <span class="n">input_tensor</span> <span class="p">)</span> <span class="n">main_branch</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span> <span class="n">channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span> <span class="p">)(</span><span class="n">main_branch</span><span class="p">)</span> <span class="n">main_branch</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">UpSampling2D</span><span class="p">()(</span><span class="n">main_branch</span><span class="p">)</span> <span class="n">main_branch</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">channels</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))(</span><span class="n">main_branch</span><span class="p">)</span> <span class="n">skip_branch</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">UpSampling2D</span><span class="p">()(</span><span class="n">input_tensor</span><span class="p">)</span> <span class="n">skip_branch</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">channels</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">))(</span><span class="n">skip_branch</span><span class="p">)</span> <span class="k">return</span> <span class="n">layers</span><span class="o">.</span><span class="n">Add</span><span class="p">()([</span><span class="n">skip_branch</span><span class="p">,</span> <span class="n">main_branch</span><span class="p">])</span> <span class="c1"># MRB Block</span> <span class="k">def</span> <span class="nf">multi_scale_residual_block</span><span class="p">(</span><span class="n">input_tensor</span><span class="p">,</span> <span class="n">channels</span><span class="p">):</span> <span class="c1"># features</span> <span class="n">level1</span> <span class="o">=</span> <span class="n">input_tensor</span> <span class="n">level2</span> <span class="o">=</span> <span class="n">down_sampling_module</span><span class="p">(</span><span class="n">input_tensor</span><span class="p">)</span> <span class="n">level3</span> <span class="o">=</span> <span class="n">down_sampling_module</span><span class="p">(</span><span class="n">level2</span><span class="p">)</span> <span class="c1"># DAU</span> <span class="n">level1_dau</span> <span class="o">=</span> <span class="n">dual_attention_unit_block</span><span class="p">(</span><span class="n">level1</span><span class="p">)</span> <span class="n">level2_dau</span> <span class="o">=</span> <span class="n">dual_attention_unit_block</span><span class="p">(</span><span class="n">level2</span><span class="p">)</span> <span class="n">level3_dau</span> <span class="o">=</span> <span class="n">dual_attention_unit_block</span><span class="p">(</span><span class="n">level3</span><span class="p">)</span> <span class="c1"># SKFF</span> <span class="n">level1_skff</span> <span class="o">=</span> <span class="n">selective_kernel_feature_fusion</span><span class="p">(</span> <span class="n">level1_dau</span><span class="p">,</span> <span class="n">up_sampling_module</span><span class="p">(</span><span class="n">level2_dau</span><span class="p">),</span> <span class="n">up_sampling_module</span><span class="p">(</span><span class="n">up_sampling_module</span><span class="p">(</span><span class="n">level3_dau</span><span class="p">)),</span> <span class="p">)</span> <span class="n">level2_skff</span> <span class="o">=</span> <span class="n">selective_kernel_feature_fusion</span><span class="p">(</span> <span class="n">down_sampling_module</span><span class="p">(</span><span class="n">level1_dau</span><span class="p">),</span> <span class="n">level2_dau</span><span class="p">,</span> <span class="n">up_sampling_module</span><span class="p">(</span><span class="n">level3_dau</span><span class="p">),</span> <span class="p">)</span> <span class="n">level3_skff</span> <span class="o">=</span> <span class="n">selective_kernel_feature_fusion</span><span class="p">(</span> <span class="n">down_sampling_module</span><span class="p">(</span><span class="n">down_sampling_module</span><span class="p">(</span><span class="n">level1_dau</span><span class="p">)),</span> <span class="n">down_sampling_module</span><span class="p">(</span><span class="n">level2_dau</span><span class="p">),</span> <span class="n">level3_dau</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># DAU 2</span> <span class="n">level1_dau_2</span> <span class="o">=</span> <span class="n">dual_attention_unit_block</span><span class="p">(</span><span class="n">level1_skff</span><span class="p">)</span> <span class="n">level2_dau_2</span> <span class="o">=</span> <span class="n">up_sampling_module</span><span class="p">((</span><span class="n">dual_attention_unit_block</span><span class="p">(</span><span class="n">level2_skff</span><span class="p">)))</span> <span class="n">level3_dau_2</span> <span class="o">=</span> <span class="n">up_sampling_module</span><span class="p">(</span> <span class="n">up_sampling_module</span><span class="p">(</span><span class="n">dual_attention_unit_block</span><span class="p">(</span><span class="n">level3_skff</span><span class="p">))</span> <span class="p">)</span> <span class="c1"># SKFF 2</span> <span class="n">skff_</span> <span class="o">=</span> <span class="n">selective_kernel_feature_fusion</span><span class="p">(</span><span class="n">level1_dau_2</span><span class="p">,</span> <span class="n">level2_dau_2</span><span class="p">,</span> <span class="n">level3_dau_2</span><span class="p">)</span> <span class="n">conv</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">)(</span><span class="n">skff_</span><span class="p">)</span> <span class="k">return</span> <span class="n">layers</span><span class="o">.</span><span class="n">Add</span><span class="p">()([</span><span class="n">input_tensor</span><span class="p">,</span> <span class="n">conv</span><span class="p">])</span> </code></pre></div> <h3 id="mirnet-model">MIRNet Model</h3> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">recursive_residual_group</span><span class="p">(</span><span class="n">input_tensor</span><span class="p">,</span> <span class="n">num_mrb</span><span class="p">,</span> <span class="n">channels</span><span class="p">):</span> <span class="n">conv1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">)(</span><span class="n">input_tensor</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_mrb</span><span class="p">):</span> <span class="n">conv1</span> <span class="o">=</span> <span class="n">multi_scale_residual_block</span><span class="p">(</span><span class="n">conv1</span><span class="p">,</span> <span class="n">channels</span><span class="p">)</span> <span class="n">conv2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">)(</span><span class="n">conv1</span><span class="p">)</span> <span class="k">return</span> <span class="n">layers</span><span class="o">.</span><span class="n">Add</span><span class="p">()([</span><span class="n">conv2</span><span class="p">,</span> <span class="n">input_tensor</span><span class="p">])</span> <span class="k">def</span> <span class="nf">mirnet_model</span><span class="p">(</span><span class="n">num_rrg</span><span class="p">,</span> <span class="n">num_mrb</span><span class="p">,</span> <span class="n">channels</span><span class="p">):</span> <span class="n">input_tensor</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="n">x1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">channels</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">)(</span><span class="n">input_tensor</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_rrg</span><span class="p">):</span> <span class="n">x1</span> <span class="o">=</span> <span class="n">recursive_residual_group</span><span class="p">(</span><span class="n">x1</span><span class="p">,</span> <span class="n">num_mrb</span><span class="p">,</span> <span class="n">channels</span><span class="p">)</span> <span class="n">conv</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">)(</span><span class="n">x1</span><span class="p">)</span> <span class="n">output_tensor</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Add</span><span class="p">()([</span><span class="n">input_tensor</span><span class="p">,</span> <span class="n">conv</span><span class="p">])</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">input_tensor</span><span class="p">,</span> <span class="n">output_tensor</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">mirnet_model</span><span class="p">(</span><span class="n">num_rrg</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">num_mrb</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">channels</span><span class="o">=</span><span class="mi">64</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="training">Training</h2> <ul> <li>We train MIRNet using <strong>Charbonnier Loss</strong> as the loss function and <strong>Adam Optimizer</strong> with a learning rate of <code>1e-4</code>.</li> <li>We use <strong>Peak Signal Noise Ratio</strong> or PSNR as a metric which is an expression for the ratio between the maximum possible value (power) of a signal and the power of distorting noise that affects the quality of its representation.</li> </ul> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">charbonnier_loss</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span> <span class="k">return</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">y_true</span> <span class="o">-</span> <span class="n">y_pred</span><span class="p">)</span> <span class="o">+</span> <span class="n">tf</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="mf">1e-3</span><span class="p">)))</span> <span class="k">def</span> <span class="nf">peak_signal_noise_ratio</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span> <span class="k">return</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">psnr</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">y_true</span><span class="p">,</span> <span class="n">max_val</span><span class="o">=</span><span class="mf">255.0</span><span class="p">)</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">1e-4</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="n">charbonnier_loss</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="n">peak_signal_noise_ratio</span><span class="p">],</span> <span class="p">)</span> <span class="n">history</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">train_dataset</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">val_dataset</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span> <span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">ReduceLROnPlateau</span><span class="p">(</span> <span class="n">monitor</span><span class="o">=</span><span class="s2">"val_peak_signal_noise_ratio"</span><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">patience</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">min_delta</span><span class="o">=</span><span class="mf">1e-7</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">"max"</span><span class="p">,</span> <span class="p">)</span> <span class="p">],</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">plot_history</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="n">value</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="sa">f</span><span class="s2">"train_</span><span class="si">{</span><span class="n">name</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="sa">f</span><span class="s2">"val_</span><span class="si">{</span><span class="n">value</span><span class="si">}</span><span class="s2">"</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="sa">f</span><span class="s2">"val_</span><span class="si">{</span><span class="n">name</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s2">"Epochs"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="n">name</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Train and Validation </span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2"> Over Epochs"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">grid</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="n">plot_history</span><span class="p">(</span><span class="s2">"loss"</span><span class="p">,</span> <span class="s2">"Loss"</span><span class="p">)</span> <span class="n">plot_history</span><span class="p">(</span><span class="s2">"peak_signal_noise_ratio"</span><span class="p">,</span> <span class="s2">"PSNR"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/50 WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1699658204.480352 77759 device_compiler.h:187] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. 75/75 ━━━━━━━━━━━━━━━━━━━━ 445s 686ms/step - loss: 0.2162 - peak_signal_noise_ratio: 61.5549 - val_loss: 0.1358 - val_peak_signal_noise_ratio: 65.2699 - learning_rate: 1.0000e-04 Epoch 2/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 385ms/step - loss: 0.1745 - peak_signal_noise_ratio: 63.1785 - val_loss: 0.1237 - val_peak_signal_noise_ratio: 65.8360 - learning_rate: 1.0000e-04 Epoch 3/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 386ms/step - loss: 0.1681 - peak_signal_noise_ratio: 63.4903 - val_loss: 0.1205 - val_peak_signal_noise_ratio: 65.9048 - learning_rate: 1.0000e-04 Epoch 4/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 385ms/step - loss: 0.1668 - peak_signal_noise_ratio: 63.4793 - val_loss: 0.1185 - val_peak_signal_noise_ratio: 66.0290 - learning_rate: 1.0000e-04 Epoch 5/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 383ms/step - loss: 0.1564 - peak_signal_noise_ratio: 63.9205 - val_loss: 0.1217 - val_peak_signal_noise_ratio: 66.1207 - learning_rate: 1.0000e-04 Epoch 6/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 384ms/step - loss: 0.1601 - peak_signal_noise_ratio: 63.9336 - val_loss: 0.1166 - val_peak_signal_noise_ratio: 66.6102 - learning_rate: 1.0000e-04 Epoch 7/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 385ms/step - loss: 0.1600 - peak_signal_noise_ratio: 63.9043 - val_loss: 0.1335 - val_peak_signal_noise_ratio: 65.5639 - learning_rate: 1.0000e-04 Epoch 8/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 382ms/step - loss: 0.1609 - peak_signal_noise_ratio: 64.0606 - val_loss: 0.1135 - val_peak_signal_noise_ratio: 66.9369 - learning_rate: 1.0000e-04 Epoch 9/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 384ms/step - loss: 0.1539 - peak_signal_noise_ratio: 64.3915 - val_loss: 0.1165 - val_peak_signal_noise_ratio: 66.9783 - learning_rate: 1.0000e-04 Epoch 10/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 43s 409ms/step - loss: 0.1536 - peak_signal_noise_ratio: 64.4491 - val_loss: 0.1118 - val_peak_signal_noise_ratio: 66.8747 - learning_rate: 1.0000e-04 Epoch 11/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 383ms/step - loss: 0.1449 - peak_signal_noise_ratio: 64.6579 - val_loss: 0.1167 - val_peak_signal_noise_ratio: 66.9626 - learning_rate: 1.0000e-04 Epoch 12/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 383ms/step - loss: 0.1501 - peak_signal_noise_ratio: 64.7929 - val_loss: 0.1143 - val_peak_signal_noise_ratio: 66.9400 - learning_rate: 1.0000e-04 Epoch 13/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 380ms/step - loss: 0.1510 - peak_signal_noise_ratio: 64.6816 - val_loss: 0.1302 - val_peak_signal_noise_ratio: 66.0576 - learning_rate: 1.0000e-04 Epoch 14/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 383ms/step - loss: 0.1632 - peak_signal_noise_ratio: 63.9234 - val_loss: 0.1146 - val_peak_signal_noise_ratio: 67.0321 - learning_rate: 1.0000e-04 Epoch 15/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 379ms/step - loss: 0.1486 - peak_signal_noise_ratio: 64.7125 - val_loss: 0.1284 - val_peak_signal_noise_ratio: 66.2105 - learning_rate: 1.0000e-04 Epoch 16/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 379ms/step - loss: 0.1482 - peak_signal_noise_ratio: 64.8123 - val_loss: 0.1176 - val_peak_signal_noise_ratio: 66.8114 - learning_rate: 1.0000e-04 Epoch 17/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 381ms/step - loss: 0.1459 - peak_signal_noise_ratio: 64.7795 - val_loss: 0.1092 - val_peak_signal_noise_ratio: 67.4173 - learning_rate: 1.0000e-04 Epoch 18/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 378ms/step - loss: 0.1482 - peak_signal_noise_ratio: 64.8821 - val_loss: 0.1175 - val_peak_signal_noise_ratio: 67.0296 - learning_rate: 1.0000e-04 Epoch 19/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 381ms/step - loss: 0.1524 - peak_signal_noise_ratio: 64.7275 - val_loss: 0.1028 - val_peak_signal_noise_ratio: 67.8485 - learning_rate: 1.0000e-04 Epoch 20/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 379ms/step - loss: 0.1350 - peak_signal_noise_ratio: 65.6166 - val_loss: 0.1040 - val_peak_signal_noise_ratio: 67.8551 - learning_rate: 1.0000e-04 Epoch 21/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 380ms/step - loss: 0.1383 - peak_signal_noise_ratio: 65.5167 - val_loss: 0.1071 - val_peak_signal_noise_ratio: 67.5902 - learning_rate: 1.0000e-04 Epoch 22/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 379ms/step - loss: 0.1393 - peak_signal_noise_ratio: 65.6293 - val_loss: 0.1096 - val_peak_signal_noise_ratio: 67.2940 - learning_rate: 1.0000e-04 Epoch 23/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 383ms/step - loss: 0.1399 - peak_signal_noise_ratio: 65.5146 - val_loss: 0.1044 - val_peak_signal_noise_ratio: 67.6932 - learning_rate: 1.0000e-04 Epoch 24/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 378ms/step - loss: 0.1390 - peak_signal_noise_ratio: 65.7525 - val_loss: 0.1135 - val_peak_signal_noise_ratio: 66.9891 - learning_rate: 1.0000e-04 Epoch 25/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 0s 326ms/step - loss: 0.1333 - peak_signal_noise_ratio: 65.8340 Epoch 25: ReduceLROnPlateau reducing learning rate to 4.999999873689376e-05. 75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 380ms/step - loss: 0.1332 - peak_signal_noise_ratio: 65.8348 - val_loss: 0.1252 - val_peak_signal_noise_ratio: 66.5684 - learning_rate: 1.0000e-04 Epoch 26/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 381ms/step - loss: 0.1547 - peak_signal_noise_ratio: 64.8968 - val_loss: 0.1105 - val_peak_signal_noise_ratio: 67.0688 - learning_rate: 5.0000e-05 Epoch 27/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 380ms/step - loss: 0.1269 - peak_signal_noise_ratio: 66.3882 - val_loss: 0.1035 - val_peak_signal_noise_ratio: 67.7006 - learning_rate: 5.0000e-05 Epoch 28/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 30s 405ms/step - loss: 0.1243 - peak_signal_noise_ratio: 66.5826 - val_loss: 0.1063 - val_peak_signal_noise_ratio: 67.2497 - learning_rate: 5.0000e-05 Epoch 29/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 383ms/step - loss: 0.1292 - peak_signal_noise_ratio: 66.1734 - val_loss: 0.1064 - val_peak_signal_noise_ratio: 67.3989 - learning_rate: 5.0000e-05 Epoch 30/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 0s 328ms/step - loss: 0.1304 - peak_signal_noise_ratio: 66.1267 Epoch 30: ReduceLROnPlateau reducing learning rate to 2.499999936844688e-05. 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 382ms/step - loss: 0.1304 - peak_signal_noise_ratio: 66.1294 - val_loss: 0.1109 - val_peak_signal_noise_ratio: 66.8935 - learning_rate: 5.0000e-05 Epoch 31/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 381ms/step - loss: 0.1141 - peak_signal_noise_ratio: 67.1338 - val_loss: 0.1145 - val_peak_signal_noise_ratio: 66.8367 - learning_rate: 2.5000e-05 Epoch 32/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 380ms/step - loss: 0.1141 - peak_signal_noise_ratio: 66.9369 - val_loss: 0.1132 - val_peak_signal_noise_ratio: 66.9264 - learning_rate: 2.5000e-05 Epoch 33/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 380ms/step - loss: 0.1184 - peak_signal_noise_ratio: 66.7723 - val_loss: 0.1090 - val_peak_signal_noise_ratio: 67.1115 - learning_rate: 2.5000e-05 Epoch 34/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 380ms/step - loss: 0.1243 - peak_signal_noise_ratio: 66.4147 - val_loss: 0.1080 - val_peak_signal_noise_ratio: 67.2300 - learning_rate: 2.5000e-05 Epoch 35/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 0s 325ms/step - loss: 0.1230 - peak_signal_noise_ratio: 66.7113 Epoch 35: ReduceLROnPlateau reducing learning rate to 1.249999968422344e-05. 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 381ms/step - loss: 0.1229 - peak_signal_noise_ratio: 66.7121 - val_loss: 0.1038 - val_peak_signal_noise_ratio: 67.5288 - learning_rate: 2.5000e-05 Epoch 36/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 380ms/step - loss: 0.1181 - peak_signal_noise_ratio: 66.9202 - val_loss: 0.1030 - val_peak_signal_noise_ratio: 67.6249 - learning_rate: 1.2500e-05 Epoch 37/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 380ms/step - loss: 0.1086 - peak_signal_noise_ratio: 67.5034 - val_loss: 0.1016 - val_peak_signal_noise_ratio: 67.6940 - learning_rate: 1.2500e-05 Epoch 38/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 380ms/step - loss: 0.1127 - peak_signal_noise_ratio: 67.3735 - val_loss: 0.1004 - val_peak_signal_noise_ratio: 68.0042 - learning_rate: 1.2500e-05 Epoch 39/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 379ms/step - loss: 0.1135 - peak_signal_noise_ratio: 67.3436 - val_loss: 0.1150 - val_peak_signal_noise_ratio: 66.9541 - learning_rate: 1.2500e-05 Epoch 40/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 381ms/step - loss: 0.1152 - peak_signal_noise_ratio: 67.1675 - val_loss: 0.1093 - val_peak_signal_noise_ratio: 67.2030 - learning_rate: 1.2500e-05 Epoch 41/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 378ms/step - loss: 0.1191 - peak_signal_noise_ratio: 66.7586 - val_loss: 0.1095 - val_peak_signal_noise_ratio: 67.2424 - learning_rate: 1.2500e-05 Epoch 42/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 30s 405ms/step - loss: 0.1062 - peak_signal_noise_ratio: 67.6856 - val_loss: 0.1092 - val_peak_signal_noise_ratio: 67.2187 - learning_rate: 1.2500e-05 Epoch 43/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 0s 323ms/step - loss: 0.1099 - peak_signal_noise_ratio: 67.6400 Epoch 43: ReduceLROnPlateau reducing learning rate to 6.24999984211172e-06. 75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 377ms/step - loss: 0.1099 - peak_signal_noise_ratio: 67.6378 - val_loss: 0.1079 - val_peak_signal_noise_ratio: 67.4591 - learning_rate: 1.2500e-05 Epoch 44/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 378ms/step - loss: 0.1155 - peak_signal_noise_ratio: 67.0911 - val_loss: 0.1019 - val_peak_signal_noise_ratio: 67.8073 - learning_rate: 6.2500e-06 Epoch 45/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 377ms/step - loss: 0.1145 - peak_signal_noise_ratio: 67.1876 - val_loss: 0.1067 - val_peak_signal_noise_ratio: 67.4283 - learning_rate: 6.2500e-06 Epoch 46/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 384ms/step - loss: 0.1077 - peak_signal_noise_ratio: 67.7168 - val_loss: 0.1114 - val_peak_signal_noise_ratio: 67.1392 - learning_rate: 6.2500e-06 Epoch 47/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 377ms/step - loss: 0.1117 - peak_signal_noise_ratio: 67.3210 - val_loss: 0.1081 - val_peak_signal_noise_ratio: 67.3622 - learning_rate: 6.2500e-06 Epoch 48/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 0s 326ms/step - loss: 0.1074 - peak_signal_noise_ratio: 67.7986 Epoch 48: ReduceLROnPlateau reducing learning rate to 3.12499992105586e-06. 75/75 ━━━━━━━━━━━━━━━━━━━━ 29s 380ms/step - loss: 0.1074 - peak_signal_noise_ratio: 67.7992 - val_loss: 0.1101 - val_peak_signal_noise_ratio: 67.3376 - learning_rate: 6.2500e-06 Epoch 49/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 380ms/step - loss: 0.1081 - peak_signal_noise_ratio: 67.5032 - val_loss: 0.1121 - val_peak_signal_noise_ratio: 67.0685 - learning_rate: 3.1250e-06 Epoch 50/50 75/75 ━━━━━━━━━━━━━━━━━━━━ 28s 378ms/step - loss: 0.1077 - peak_signal_noise_ratio: 67.6709 - val_loss: 0.1084 - val_peak_signal_noise_ratio: 67.6183 - learning_rate: 3.1250e-06 </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/mirnet/mirnet_17_3.png" /></p> <p><img alt="png" src="/img/examples/vision/mirnet/mirnet_17_4.png" /></p> <hr /> <h2 id="inference">Inference</h2> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">plot_results</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">titles</span><span class="p">,</span> <span class="n">figure_size</span><span class="o">=</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">12</span><span class="p">)):</span> <span class="n">fig</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="n">figure_size</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">images</span><span class="p">)):</span> <span class="n">fig</span><span class="o">.</span><span class="n">add_subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">images</span><span class="p">),</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="n">titles</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="n">_</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">images</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="k">def</span> <span class="nf">infer</span><span class="p">(</span><span class="n">original_image</span><span class="p">):</span> <span class="n">image</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">img_to_array</span><span class="p">(</span><span class="n">original_image</span><span class="p">)</span> <span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">)</span> <span class="o">/</span> <span class="mf">255.0</span> <span class="n">image</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">output</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">output_image</span> <span class="o">=</span> <span class="n">output</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="mf">255.0</span> <span class="n">output_image</span> <span class="o">=</span> <span class="n">output_image</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">255</span><span class="p">)</span> <span class="n">output_image</span> <span class="o">=</span> <span class="n">output_image</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">output_image</span><span class="p">)[</span><span class="mi">0</span><span class="p">],</span> <span class="n">np</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">output_image</span><span class="p">)[</span><span class="mi">1</span><span class="p">],</span> <span class="mi">3</span><span class="p">)</span> <span class="p">)</span> <span class="n">output_image</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">fromarray</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">(</span><span class="n">output_image</span><span class="p">))</span> <span class="n">original_image</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">fromarray</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">(</span><span class="n">original_image</span><span class="p">))</span> <span class="k">return</span> <span class="n">output_image</span> </code></pre></div> <h3 id="inference-on-test-images">Inference on Test Images</h3> <p>We compare the test images from LOLDataset enhanced by MIRNet with images enhanced via the <code>PIL.ImageOps.autocontrast()</code> function.</p> <p>You can use the trained model hosted on <a href="https://huggingface.co/keras-io/lowlight-enhance-mirnet">Hugging Face Hub</a> and try the demo on <a href="https://huggingface.co/spaces/keras-io/Enhance_Low_Light_Image">Hugging Face Spaces</a>.</p> <div class="codehilite"><pre><span></span><code><span class="k">for</span> <span class="n">low_light_image</span> <span class="ow">in</span> <span class="n">random</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">test_low_light_images</span><span class="p">,</span> <span class="mi">6</span><span class="p">):</span> <span class="n">original_image</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">low_light_image</span><span class="p">)</span> <span class="n">enhanced_image</span> <span class="o">=</span> <span class="n">infer</span><span class="p">(</span><span class="n">original_image</span><span class="p">)</span> <span class="n">plot_results</span><span class="p">(</span> <span class="p">[</span><span class="n">original_image</span><span class="p">,</span> <span class="n">ImageOps</span><span class="o">.</span><span class="n">autocontrast</span><span class="p">(</span><span class="n">original_image</span><span class="p">),</span> <span class="n">enhanced_image</span><span class="p">],</span> <span class="p">[</span><span class="s2">"Original"</span><span class="p">,</span> <span class="s2">"PIL Autocontrast"</span><span class="p">,</span> <span class="s2">"MIRNet Enhanced"</span><span class="p">],</span> <span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">12</span><span class="p">),</span> <span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/mirnet/mirnet_21_0.png" /></p> <p><img alt="png" src="/img/examples/vision/mirnet/mirnet_21_1.png" /></p> <p><img alt="png" src="/img/examples/vision/mirnet/mirnet_21_2.png" /></p> <p><img alt="png" src="/img/examples/vision/mirnet/mirnet_21_3.png" /></p> <p><img alt="png" src="/img/examples/vision/mirnet/mirnet_21_4.png" /></p> <p><img alt="png" src="/img/examples/vision/mirnet/mirnet_21_5.png" /></p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#lowlight-image-enhancement-using-mirnet'>Low-light image enhancement using MIRNet</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-3'> <a href='#references'>References:</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#downloading-loldataset'>Downloading LOLDataset</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#creating-a-tensorflow-dataset'>Creating a TensorFlow Dataset</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#mirnet-model'>MIRNet Model</a> </div> <div class='k-outline-depth-3'> <a href='#selective-kernel-feature-fusion'>Selective Kernel Feature Fusion</a> </div> <div class='k-outline-depth-3'> <a href='#dual-attention-unit'>Dual Attention Unit</a> </div> <div class='k-outline-depth-3'> <a href='#multiscale-residual-block'>Multi-Scale Residual Block</a> </div> <div class='k-outline-depth-3'> <a href='#mirnet-model'>MIRNet Model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#training'>Training</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#inference'>Inference</a> </div> <div class='k-outline-depth-3'> <a href='#inference-on-test-images'>Inference on Test Images</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>