# Visualizing what convnets learn

**Author:** [fchollet](https://twitter.com/fchollet)<br>
**Date created:** 2020/05/29<br>
**Last modified:** 2020/05/29<br>
**Description:** Displaying the visual patterns that convnet filters respond to.

ⓘ This example uses Keras 3

[**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/visualizing_what_convnets_learn.ipynb) • [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/vision/visualizing_what_convnets_learn.py)

---

## Introduction

In this example, we look into what sort of visual patterns image classification models learn. We'll be using the `ResNet50V2` model, trained on the ImageNet dataset.

Our process is simple: we will create input images that maximize the activation of specific filters in a target layer (picked somewhere in the middle of the model: layer `conv3_block4_out`).
Such images represent a visualization of the pattern that the filter responds to.

---

## Setup

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import numpy as np
import tensorflow as tf

# The dimensions of our input image
img_width = 180
img_height = 180
# Our target layer: we will visualize the filters from this layer.
# See `model.summary()` for list of layer names, if you want to change this.
layer_name = "conv3_block4_out"
```

---

## Build a feature extraction model

```python
# Build a ResNet50V2 model loaded with pre-trained ImageNet weights
model = keras.applications.ResNet50V2(weights="imagenet", include_top=False)

# Set up a model that returns the activation values for our target layer
layer = model.get_layer(name=layer_name)
feature_extractor = keras.Model(inputs=model.inputs, outputs=layer.output)
```
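The target layer is configurable. If you want to visualize a different layer, you can inspect the model for candidate names; the snippet below is a small optional helper (not part of the original example) that prints the output layer of every residual block, any of which could be assigned to `layer_name` above.

```python
# Optional helper (assumption: ResNet50V2 names each residual block's output
# layer with an "_out" suffix, e.g. "conv3_block4_out").
for candidate in model.layers:
    if candidate.name.endswith("_out"):
        print(candidate.name, candidate.output.shape)
```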
---

## Set up the gradient ascent process

The "loss" we will maximize is simply the mean of the activation of a specific filter in our target layer. To avoid border effects, we exclude border pixels.

```python
def compute_loss(input_image, filter_index):
    activation = feature_extractor(input_image)
    # We avoid border artifacts by only involving non-border pixels in the loss.
    filter_activation = activation[:, 2:-2, 2:-2, filter_index]
    return tf.reduce_mean(filter_activation)
```

Our gradient ascent function simply computes the gradients of the loss above with regard to the input image, and updates the input image so as to move it towards a state that activates the target filter more strongly.

```python
@tf.function
def gradient_ascent_step(img, filter_index, learning_rate):
    with tf.GradientTape() as tape:
        tape.watch(img)
        loss = compute_loss(img, filter_index)
    # Compute gradients.
    grads = tape.gradient(loss, img)
    # Normalize gradients.
    grads = tf.math.l2_normalize(grads)
    img += learning_rate * grads
    return loss, img
```
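As an optional sanity check (not part of the original example), you can run a few ascent steps on a throwaway random image and verify that the loss for a given filter goes up:

```python
# Illustrative only: the loss reported for filter 0 should (usually) increase
# from one step to the next. The real initialization is done by
# `initialize_image()` in the next section; this is just a rough stand-in.
img = tf.random.uniform((1, img_width, img_height, 3)) - 0.5
for step in range(3):
    loss, img = gradient_ascent_step(img, filter_index=0, learning_rate=10.0)
    print(f"step {step}: loss = {float(loss):.4f}")
```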
close to "all gray" (i.e. visually netural)</li> <li>Repeatedly apply the gradient ascent step function defined above</li> <li>Convert the resulting input image back to a displayable form, by normalizing it, center-cropping it, and restricting it to the [0, 255] range.</li> </ul> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">initialize_image</span><span class="p">():</span> <span class="c1"># We start from a gray image with some random noise</span> <span class="n">img</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="n">img_width</span><span class="p">,</span> <span class="n">img_height</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="c1"># ResNet50V2 expects inputs in the range [-1, +1].</span> <span class="c1"># Here we scale our random inputs to [-0.125, +0.125]</span> <span class="k">return</span> <span class="p">(</span><span class="n">img</span> <span class="o">-</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.25</span> <span class="k">def</span> <span class="nf">visualize_filter</span><span class="p">(</span><span class="n">filter_index</span><span class="p">):</span> <span class="c1"># We run gradient ascent for 20 steps</span> <span class="n">iterations</span> <span class="o">=</span> <span class="mi">30</span> <span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">10.0</span> <span class="n">img</span> <span class="o">=</span> <span class="n">initialize_image</span><span class="p">()</span> <span class="k">for</span> <span class="n">iteration</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">iterations</span><span class="p">):</span> <span class="n">loss</span><span class="p">,</span> <span class="n">img</span> <span class="o">=</span> <span class="n">gradient_ascent_step</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">filter_index</span><span class="p">,</span> <span class="n">learning_rate</span><span class="p">)</span> <span class="c1"># Decode the resulting input image</span> <span class="n">img</span> <span class="o">=</span> <span class="n">deprocess_image</span><span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span> <span class="k">return</span> <span class="n">loss</span><span class="p">,</span> <span class="n">img</span> <span class="k">def</span> <span class="nf">deprocess_image</span><span class="p">(</span><span class="n">img</span><span class="p">):</span> <span class="c1"># Normalize array: center on 0., ensure variance is 0.15</span> <span class="n">img</span> <span class="o">-=</span> <span class="n">img</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span> <span class="n">img</span> <span class="o">/=</span> <span class="n">img</span><span class="o">.</span><span class="n">std</span><span class="p">()</span> <span class="o">+</span> <span class="mf">1e-5</span> <span class="n">img</span> <span class="o">*=</span> <span class="mf">0.15</span> <span class="c1"># Center crop</span> <span class="n">img</span> <span 
class="o">=</span> <span class="n">img</span><span class="p">[</span><span class="mi">25</span><span class="p">:</span><span class="o">-</span><span class="mi">25</span><span class="p">,</span> <span class="mi">25</span><span class="p">:</span><span class="o">-</span><span class="mi">25</span><span class="p">,</span> <span class="p">:]</span> <span class="c1"># Clip to [0, 1]</span> <span class="n">img</span> <span class="o">+=</span> <span class="mf">0.5</span> <span class="n">img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="c1"># Convert to RGB array</span> <span class="n">img</span> <span class="o">*=</span> <span class="mi">255</span> <span class="n">img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">255</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"uint8"</span><span class="p">)</span> <span class="k">return</span> <span class="n">img</span> </code></pre></div> <p>Let's try it out with filter 0 in the target layer:</p> <div class="codehilite"><pre><span></span><code><span class="kn">from</span> <span class="nn">IPython.display</span> <span class="kn">import</span> <span class="n">Image</span><span class="p">,</span> <span class="n">display</span> <span class="n">loss</span><span class="p">,</span> <span class="n">img</span> <span class="o">=</span> <span class="n">visualize_filter</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">save_img</span><span class="p">(</span><span class="s2">"0.png"</span><span class="p">,</span> <span class="n">img</span><span class="p">)</span> </code></pre></div> <p>This is what an input that maximizes the response of filter 0 in the target layer would look like:</p> <div class="codehilite"><pre><span></span><code><span class="n">display</span><span class="p">(</span><span class="n">Image</span><span class="p">(</span><span class="s2">"0.png"</span><span class="p">))</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/visualizing_what_convnets_learn/visualizing_what_convnets_learn_15_0.png" /></p> <hr /> <h2 id="visualize-the-first-64-filters-in-the-target-layer">Visualize the first 64 filters in the target layer</h2> <p>Now, let's make a 8x8 grid of the first 64 filters in the target layer to get of feel for the range of different visual patterns that the model has learned.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Compute image inputs that maximize per-filter activations</span> <span class="c1"># for the first 64 filters of our target layer</span> <span class="n">all_imgs</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">filter_index</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">64</span><span class="p">):</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Processing filter </span><span 
class="si">%d</span><span class="s2">"</span> <span class="o">%</span> <span class="p">(</span><span class="n">filter_index</span><span class="p">,))</span> <span class="n">loss</span><span class="p">,</span> <span class="n">img</span> <span class="o">=</span> <span class="n">visualize_filter</span><span class="p">(</span><span class="n">filter_index</span><span class="p">)</span> <span class="n">all_imgs</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">img</span><span class="p">)</span> <span class="c1"># Build a black picture with enough space for</span> <span class="c1"># our 8 x 8 filters of size 128 x 128, with a 5px margin in between</span> <span class="n">margin</span> <span class="o">=</span> <span class="mi">5</span> <span class="n">n</span> <span class="o">=</span> <span class="mi">8</span> <span class="n">cropped_width</span> <span class="o">=</span> <span class="n">img_width</span> <span class="o">-</span> <span class="mi">25</span> <span class="o">*</span> <span class="mi">2</span> <span class="n">cropped_height</span> <span class="o">=</span> <span class="n">img_height</span> <span class="o">-</span> <span class="mi">25</span> <span class="o">*</span> <span class="mi">2</span> <span class="n">width</span> <span class="o">=</span> <span class="n">n</span> <span class="o">*</span> <span class="n">cropped_width</span> <span class="o">+</span> <span class="p">(</span><span class="n">n</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">margin</span> <span class="n">height</span> <span class="o">=</span> <span class="n">n</span> <span class="o">*</span> <span class="n">cropped_height</span> <span class="o">+</span> <span class="p">(</span><span class="n">n</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">margin</span> <span class="n">stitched_filters</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">width</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="c1"># Fill the picture with our saved filters</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n</span><span class="p">):</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n</span><span class="p">):</span> <span class="n">img</span> <span class="o">=</span> <span class="n">all_imgs</span><span class="p">[</span><span class="n">i</span> <span class="o">*</span> <span class="n">n</span> <span class="o">+</span> <span class="n">j</span><span class="p">]</span> <span class="n">stitched_filters</span><span class="p">[</span> <span class="p">(</span><span class="n">cropped_width</span> <span class="o">+</span> <span class="n">margin</span><span class="p">)</span> <span class="o">*</span> <span class="n">i</span> <span class="p">:</span> <span class="p">(</span><span class="n">cropped_width</span> <span class="o">+</span> <span class="n">margin</span><span class="p">)</span> <span class="o">*</span> <span class="n">i</span> <span class="o">+</span> <span class="n">cropped_width</span><span class="p">,</span> <span class="p">(</span><span 
class="n">cropped_height</span> <span class="o">+</span> <span class="n">margin</span><span class="p">)</span> <span class="o">*</span> <span class="n">j</span> <span class="p">:</span> <span class="p">(</span><span class="n">cropped_height</span> <span class="o">+</span> <span class="n">margin</span><span class="p">)</span> <span class="o">*</span> <span class="n">j</span> <span class="o">+</span> <span class="n">cropped_height</span><span class="p">,</span> <span class="p">:,</span> <span class="p">]</span> <span class="o">=</span> <span class="n">img</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">save_img</span><span class="p">(</span><span class="s2">"stiched_filters.png"</span><span class="p">,</span> <span class="n">stitched_filters</span><span class="p">)</span> <span class="kn">from</span> <span class="nn">IPython.display</span> <span class="kn">import</span> <span class="n">Image</span><span class="p">,</span> <span class="n">display</span> <span class="n">display</span><span class="p">(</span><span class="n">Image</span><span class="p">(</span><span class="s2">"stiched_filters.png"</span><span class="p">))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Processing filter 0 Processing filter 1 Processing filter 2 Processing filter 3 Processing filter 4 Processing filter 5 Processing filter 6 Processing filter 7 Processing filter 8 Processing filter 9 Processing filter 10 Processing filter 11 Processing filter 12 Processing filter 13 Processing filter 14 Processing filter 15 Processing filter 16 Processing filter 17 Processing filter 18 Processing filter 19 Processing filter 20 Processing filter 21 Processing filter 22 Processing filter 23 Processing filter 24 Processing filter 25 Processing filter 26 Processing filter 27 Processing filter 28 Processing filter 29 Processing filter 30 Processing filter 31 Processing filter 32 Processing filter 33 Processing filter 34 Processing filter 35 Processing filter 36 Processing filter 37 Processing filter 38 Processing filter 39 Processing filter 40 Processing filter 41 Processing filter 42 Processing filter 43 Processing filter 44 Processing filter 45 Processing filter 46 Processing filter 47 Processing filter 48 Processing filter 49 Processing filter 50 Processing filter 51 Processing filter 52 Processing filter 53 Processing filter 54 Processing filter 55 Processing filter 56 Processing filter 57 Processing filter 58 Processing filter 59 Processing filter 60 Processing filter 61 Processing filter 62 Processing filter 63 </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/visualizing_what_convnets_learn/visualizing_what_convnets_learn_17_1.png" /></p> <p>Image classification models see the world by decomposing their inputs over a "vector basis" of texture filters such as these.</p> <p>See also <a href="https://blog.keras.io/how-convolutional-neural-networks-see-the-world.html">this old blog post</a> for analysis and interpretation.</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#visualizing-what-convnets-learn'>Visualizing what convnets learn</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#build-a-feature-extraction-model'>Build a feature extraction model</a> </div> <div class='k-outline-depth-2'> ◆ 