Model interpretability with Integrated Gradients

<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/integrated_gradients/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Model interpretability with Integrated Gradients"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Model interpretability with Integrated Gradients"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Model interpretability with Integrated Gradients</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 
'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a class="nav-sublink2" href="/examples/vision/mobilevit/">A mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" 
href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using Composable Fully-Convolutional 
Networks</a> <a class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved Robustness</a> <a 
class="nav-sublink2" href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2 active" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a class="nav-sublink2" 
href="/examples/vision/bit/">Image Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a class="nav-sublink2" 
href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } 
window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/vision/'>Computer 
Vision</a> / Model interpretability with Integrated Gradients </div> <div class='k-content'> <h1 id="model-interpretability-with-integrated-gradients">Model interpretability with Integrated Gradients</h1> <p><strong>Author:</strong> <a href="https://twitter.com/A_K_Nain">A_K_Nain</a><br> <strong>Date created:</strong> 2020/06/02<br> <strong>Last modified:</strong> 2020/06/02<br> <strong>Description:</strong> How to obtain integrated gradients for a classification model.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/integrated_gradients.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/integrated_gradients.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="integrated-gradients">Integrated Gradients</h2> <p><a href="https://arxiv.org/abs/1703.01365">Integrated Gradients</a> is a technique for attributing a classification model's prediction to its input features. It is a model interpretability technique: you can use it to visualize the relationship between input features and model predictions.</p> <p>Integrated Gradients is a variation on computing the gradient of the prediction output with regard to features of the input. To compute integrated gradients, we need to perform the following steps:</p> <ol> <li> <p>Identify the input and the output. In our case, the input is an image and the output is the last layer of our model (dense layer with softmax activation).</p> </li> <li> <p>Compute which features are important to a neural network when making a prediction on a particular data point. To identify these features, we need to choose a baseline input. 
A baseline input can be a black image (all pixel values set to zero) or random noise. The shape of the baseline input needs to be the same as our input image, e.g. (299, 299, 3).</p> </li> <li> <p>Interpolate between the baseline and the input image for a given number of steps. The number of steps is a hyperparameter that sets how many points are used to approximate the integral of the gradients along this path. The authors of the paper report that somewhere between 20 and 300 steps is usually enough.</p> </li> <li> <p>Preprocess these interpolated images and do a forward pass.</p> </li> <li>Get the gradients for these interpolated images.</li> <li>Approximate the integral of the gradients using the trapezoidal rule.</li> </ol> <p>For an in-depth discussion of integrated gradients and why the method works, see this excellent <a href="https://distill.pub/2020/attribution-baselines/">article</a>.</p> <p><strong>References:</strong></p> <ul> <li>Integrated Gradients original <a href="https://arxiv.org/abs/1703.01365">paper</a></li> <li><a href="https://github.com/ankurtaly/Integrated-Gradients">Original implementation</a></li> </ul> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">matplotlib.pyplot</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">plt</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">scipy</span><span class="w"> </span><span class="kn">import</span> <span class="n">ndimage</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">IPython.display</span><span class="w"> </span><span class="kn">import</span> <span class="n">Image</span><span class="p">,</span> <span class="n">display</span>
<span class="kn">import</span><span 
class="w"> </span><span class="nn">tensorflow</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">tf</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">keras</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">keras</span><span class="w"> </span><span class="kn">import</span> <span class="n">layers</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">keras.applications</span><span class="w"> </span><span class="kn">import</span> <span class="n">xception</span>

<span class="c1"># Size of the input image</span>
<span class="n">img_size</span> <span class="o">=</span> <span class="p">(</span><span class="mi">299</span><span class="p">,</span> <span class="mi">299</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>

<span class="c1"># Load Xception model with imagenet weights</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">xception</span><span class="o">.</span><span class="n">Xception</span><span class="p">(</span><span class="n">weights</span><span class="o">=</span><span class="s2">&quot;imagenet&quot;</span><span class="p">)</span>

<span class="c1"># The local path to our target image</span>
<span class="n">img_path</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">get_file</span><span class="p">(</span><span class="s2">&quot;elephant.jpg&quot;</span><span class="p">,</span> <span class="s2">&quot;https://i.imgur.com/Bvro0YD.png&quot;</span><span class="p">)</span>
<span class="n">display</span><span class="p">(</span><span class="n">Image</span><span class="p">(</span><span class="n">img_path</span><span class="p">))</span>
</code></pre></div> <p><img alt="jpeg" src="/img/examples/vision/integrated_gradients/integrated_gradients_3_0.jpg" /></p> <hr /> <h2 
id="integrated-gradients-algorithm">Integrated Gradients algorithm</h2> <div class="codehilite"><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">get_img_array</span><span class="p">(</span><span class="n">img_path</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">299</span><span class="p">,</span> <span class="mi">299</span><span class="p">)):</span>
    <span class="c1"># `img` is a PIL image of size 299x299</span>
    <span class="n">img</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">load_img</span><span class="p">(</span><span class="n">img_path</span><span class="p">,</span> <span class="n">target_size</span><span class="o">=</span><span class="n">size</span><span class="p">)</span>
    <span class="c1"># `array` is a float32 Numpy array of shape (299, 299, 3)</span>
    <span class="n">array</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">img_to_array</span><span class="p">(</span><span class="n">img</span><span class="p">)</span>
    <span class="c1"># We add a dimension to transform our array into a &quot;batch&quot;</span>
    <span class="c1"># of size (1, 299, 299, 3)</span>
    <span class="n">array</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">array</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">array</span>


<span class="k">def</span><span class="w"> </span><span class="nf">get_gradients</span><span class="p">(</span><span class="n">img_input</span><span class="p">,</span> <span class="n">top_pred_idx</span><span 
class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;Computes the gradients of outputs w.r.t input image.</span>

<span class="sd">    Args:</span>
<span class="sd">        img_input: 4D image tensor</span>
<span class="sd">        top_pred_idx: Predicted label for the input image</span>

<span class="sd">    Returns:</span>
<span class="sd">        Gradients of the predictions w.r.t img_input</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="n">images</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">img_input</span><span class="p">,</span> <span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>

    <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span>
        <span class="n">tape</span><span class="o">.</span><span class="n">watch</span><span class="p">(</span><span class="n">images</span><span class="p">)</span>
        <span class="n">preds</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">images</span><span class="p">)</span>
        <span class="n">top_class</span> <span class="o">=</span> <span class="n">preds</span><span class="p">[:,</span> <span class="n">top_pred_idx</span><span class="p">]</span>

    <span class="n">grads</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">top_class</span><span class="p">,</span> <span class="n">images</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">grads</span>


<span class="k">def</span><span class="w"> </span><span class="nf">get_integrated_gradients</span><span class="p">(</span><span class="n">img_input</span><span class="p">,</span> <span 
class="n">top_pred_idx</span><span class="p">,</span> <span class="n">baseline</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">num_steps</span><span class="o">=</span><span class="mi">50</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;Computes Integrated Gradients for a predicted label.</span>

<span class="sd">    Args:</span>
<span class="sd">        img_input (ndarray): Original image</span>
<span class="sd">        top_pred_idx: Predicted label for the input image</span>
<span class="sd">        baseline (ndarray): The baseline image to start with for interpolation</span>
<span class="sd">        num_steps: Number of interpolation steps between the baseline</span>
<span class="sd">            and the input used in the computation of integrated gradients. These</span>
<span class="sd">            steps alone determine the integral approximation error. By default,</span>
<span class="sd">            num_steps is set to 50.</span>

<span class="sd">    Returns:</span>
<span class="sd">        Integrated gradients w.r.t input image</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="c1"># If baseline is not provided, start with a black image</span>
    <span class="c1"># having same size as the input image.</span>
    <span class="k">if</span> <span class="n">baseline</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
        <span class="n">baseline</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">img_size</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="n">baseline</span> <span class="o">=</span> <span class="n">baseline</span><span class="o">.</span><span class="n">astype</span><span 
class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>

    <span class="c1"># 1. Do interpolation.</span>
    <span class="n">img_input</span> <span class="o">=</span> <span class="n">img_input</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>
    <span class="n">interpolated_image</span> <span class="o">=</span> <span class="p">[</span>
        <span class="n">baseline</span> <span class="o">+</span> <span class="p">(</span><span class="n">step</span> <span class="o">/</span> <span class="n">num_steps</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">img_input</span> <span class="o">-</span> <span class="n">baseline</span><span class="p">)</span>
        <span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_steps</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
    <span class="p">]</span>
    <span class="n">interpolated_image</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">interpolated_image</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>

    <span class="c1"># 2. Preprocess the interpolated images</span>
    <span class="n">interpolated_image</span> <span class="o">=</span> <span class="n">xception</span><span class="o">.</span><span class="n">preprocess_input</span><span class="p">(</span><span class="n">interpolated_image</span><span class="p">)</span>

    <span class="c1"># 3. Get the gradients</span>
    <span class="n">grads</span> <span class="o">=</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">img</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">interpolated_image</span><span class="p">):</span>
        <span class="n">img</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
        <span class="n">grad</span> <span class="o">=</span> <span class="n">get_gradients</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">top_pred_idx</span><span class="o">=</span><span class="n">top_pred_idx</span><span class="p">)</span>
        <span class="n">grads</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">grad</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
    <span class="n">grads</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">convert_to_tensor</span><span class="p">(</span><span class="n">grads</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span>

    <span class="c1"># 4. Approximate the integral using the trapezoidal rule</span>
    <span class="n">grads</span> <span class="o">=</span> <span class="p">(</span><span class="n">grads</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">grads</span><span class="p">[</span><span class="mi">1</span><span class="p">:])</span> <span class="o">/</span> <span class="mf">2.0</span>
    <span class="n">avg_grads</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">grads</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>

    <span class="c1"># 5. Calculate integrated gradients and return</span>
    <span class="n">integrated_grads</span> <span class="o">=</span> <span class="p">(</span><span class="n">img_input</span> <span class="o">-</span> <span class="n">baseline</span><span class="p">)</span> <span class="o">*</span> <span class="n">avg_grads</span>
    <span class="k">return</span> <span class="n">integrated_grads</span>


<span class="k">def</span><span class="w"> </span><span class="nf">random_baseline_integrated_gradients</span><span class="p">(</span>
    <span class="n">img_input</span><span class="p">,</span> <span class="n">top_pred_idx</span><span class="p">,</span> <span class="n">num_steps</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">num_runs</span><span class="o">=</span><span class="mi">2</span>
<span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;Averages integrated gradients over several random baseline images.</span>

<span class="sd">    Args:</span>
<span class="sd">        img_input (ndarray): 3D image</span>
<span class="sd">        top_pred_idx: Predicted label for the input image</span>
<span class="sd">        num_steps: Number of interpolation steps between the baseline</span>
<span 
class="sd"> and the input used in the computation of integrated gradients. These</span> <span class="sd"> steps alone determine the integral approximation error. By default,</span> <span class="sd"> num_steps is set to 50.</span> <span class="sd"> num_runs: Number of baseline images to generate</span> <span class="sd"> Returns:</span> <span class="sd"> Averaged integrated gradients for `num_runs` baseline images</span> <span class="sd"> &quot;&quot;&quot;</span> <span class="c1"># 1. List to keep track of Integrated Gradients (IG) for all the baselines</span> <span class="n">integrated_grads</span> <span class="o">=</span> <span class="p">[]</span> <span class="c1"># 2. Get the integrated gradients for all the baselines</span> <span class="k">for</span> <span class="n">run</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_runs</span><span class="p">):</span> <span class="n">baseline</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">(</span><span class="n">img_size</span><span class="p">)</span> <span class="o">*</span> <span class="mi">255</span> <span class="n">igrads</span> <span class="o">=</span> <span class="n">get_integrated_gradients</span><span class="p">(</span> <span class="n">img_input</span><span class="o">=</span><span class="n">img_input</span><span class="p">,</span> <span class="n">top_pred_idx</span><span class="o">=</span><span class="n">top_pred_idx</span><span class="p">,</span> <span class="n">baseline</span><span class="o">=</span><span class="n">baseline</span><span class="p">,</span> <span class="n">num_steps</span><span class="o">=</span><span class="n">num_steps</span><span class="p">,</span> <span class="p">)</span> <span class="n">integrated_grads</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span 
class="n">igrads</span><span class="p">)</span> <span class="c1"># 3. Return the average integrated gradients for the image</span> <span class="n">integrated_grads</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">convert_to_tensor</span><span class="p">(</span><span class="n">integrated_grads</span><span class="p">)</span> <span class="k">return</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">integrated_grads</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="helper-class-for-visualizing-gradients-and-integrated-gradients">Helper class for visualizing gradients and integrated gradients</h2> <div class="codehilite"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">GradVisualizer</span><span class="p">:</span> <span class="w"> </span><span class="sd">&quot;&quot;&quot;Plot gradients of the outputs w.r.t an input image.&quot;&quot;&quot;</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">positive_channel</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">negative_channel</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="k">if</span> <span class="n">positive_channel</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">positive_channel</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">255</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span 
class="k">else</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">positive_channel</span> <span class="o">=</span> <span class="n">positive_channel</span> <span class="k">if</span> <span class="n">negative_channel</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">negative_channel</span> <span class="o">=</span> <span class="p">[</span><span class="mi">255</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span class="k">else</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">negative_channel</span> <span class="o">=</span> <span class="n">negative_channel</span> <span class="k">def</span><span class="w"> </span><span class="nf">apply_polarity</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">attributions</span><span class="p">,</span> <span class="n">polarity</span><span class="p">):</span> <span class="k">if</span> <span class="n">polarity</span> <span class="o">==</span> <span class="s2">&quot;positive&quot;</span><span class="p">:</span> <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">attributions</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">attributions</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="k">def</span><span class="w"> </span><span 
class="nf">apply_linear_transformation</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">attributions</span><span class="p">,</span> <span class="n">clip_above_percentile</span><span class="o">=</span><span class="mf">99.9</span><span class="p">,</span> <span class="n">clip_below_percentile</span><span class="o">=</span><span class="mf">70.0</span><span class="p">,</span> <span class="n">lower_end</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="p">):</span> <span class="c1"># 1. Get the thresholds</span> <span class="n">m</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_thresholded_attributions</span><span class="p">(</span> <span class="n">attributions</span><span class="p">,</span> <span class="n">percentage</span><span class="o">=</span><span class="mi">100</span> <span class="o">-</span> <span class="n">clip_above_percentile</span> <span class="p">)</span> <span class="n">e</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_thresholded_attributions</span><span class="p">(</span> <span class="n">attributions</span><span class="p">,</span> <span class="n">percentage</span><span class="o">=</span><span class="mi">100</span> <span class="o">-</span> <span class="n">clip_below_percentile</span> <span class="p">)</span> <span class="c1"># 2. 
Transform the attributions by a linear function f(x) = a*x + b such that</span> <span class="c1"># f(m) = 1.0 and f(e) = lower_end</span> <span class="n">transformed_attributions</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">lower_end</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">attributions</span><span class="p">)</span> <span class="o">-</span> <span class="n">e</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span> <span class="n">m</span> <span class="o">-</span> <span class="n">e</span> <span class="p">)</span> <span class="o">+</span> <span class="n">lower_end</span> <span class="c1"># 3. Make sure that the sign of transformed attributions is the same as original attributions</span> <span class="n">transformed_attributions</span> <span class="o">*=</span> <span class="n">np</span><span class="o">.</span><span class="n">sign</span><span class="p">(</span><span class="n">attributions</span><span class="p">)</span> <span class="c1"># 4. Only keep values that are bigger than the lower_end</span> <span class="n">transformed_attributions</span> <span class="o">*=</span> <span class="n">transformed_attributions</span> <span class="o">&gt;=</span> <span class="n">lower_end</span> <span class="c1"># 5. 
Clip values and return</span> <span class="n">transformed_attributions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">transformed_attributions</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">1.0</span><span class="p">)</span> <span class="k">return</span> <span class="n">transformed_attributions</span> <span class="k">def</span><span class="w"> </span><span class="nf">get_thresholded_attributions</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">attributions</span><span class="p">,</span> <span class="n">percentage</span><span class="p">):</span> <span class="k">if</span> <span class="n">percentage</span> <span class="o">==</span> <span class="mf">100.0</span><span class="p">:</span> <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">attributions</span><span class="p">)</span> <span class="c1"># 1. Flatten the attributions</span> <span class="n">flatten_attr</span> <span class="o">=</span> <span class="n">attributions</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span> <span class="c1"># 2. Get the sum of the attributions</span> <span class="n">total</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">flatten_attr</span><span class="p">)</span> <span class="c1"># 3. 
Sort the absolute attributions from largest to smallest.</span> <span class="n">sorted_attributions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sort</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">flatten_attr</span><span class="p">))[::</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="c1"># 4. Calculate the percentage of the total sum that each attribution</span> <span class="c1"># and the values above it contribute.</span> <span class="n">cum_sum</span> <span class="o">=</span> <span class="mf">100.0</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">sorted_attributions</span><span class="p">)</span> <span class="o">/</span> <span class="n">total</span> <span class="c1"># 5. Threshold the attributions by the percentage</span> <span class="n">indices_to_consider</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">cum_sum</span> <span class="o">&gt;=</span> <span class="n">percentage</span><span class="p">)[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="c1"># 6. 
Select the desired attributions and return</span> <span class="n">attributions</span> <span class="o">=</span> <span class="n">sorted_attributions</span><span class="p">[</span><span class="n">indices_to_consider</span><span class="p">]</span> <span class="k">return</span> <span class="n">attributions</span> <span class="k">def</span><span class="w"> </span><span class="nf">binarize</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">attributions</span><span class="p">,</span> <span class="n">threshold</span><span class="o">=</span><span class="mf">0.001</span><span class="p">):</span> <span class="k">return</span> <span class="n">attributions</span> <span class="o">&gt;</span> <span class="n">threshold</span> <span class="k">def</span><span class="w"> </span><span class="nf">morphological_cleanup_fn</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">attributions</span><span class="p">,</span> <span class="n">structure</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">))):</span> <span class="n">closed</span> <span class="o">=</span> <span class="n">ndimage</span><span class="o">.</span><span class="n">grey_closing</span><span class="p">(</span><span class="n">attributions</span><span class="p">,</span> <span class="n">structure</span><span class="o">=</span><span class="n">structure</span><span class="p">)</span> <span class="n">opened</span> <span class="o">=</span> <span class="n">ndimage</span><span class="o">.</span><span class="n">grey_opening</span><span class="p">(</span><span class="n">closed</span><span class="p">,</span> <span class="n">structure</span><span class="o">=</span><span class="n">structure</span><span class="p">)</span> <span class="k">return</span> <span class="n">opened</span> 
<span class="k">def</span><span class="w"> </span><span class="nf">draw_outlines</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">attributions</span><span class="p">,</span> <span class="n">percentage</span><span class="o">=</span><span class="mi">90</span><span class="p">,</span> <span class="n">connected_component_structure</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">)),</span> <span class="p">):</span> <span class="c1"># 1. Binarize the attributions.</span> <span class="n">attributions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">binarize</span><span class="p">(</span><span class="n">attributions</span><span class="p">)</span> <span class="c1"># 2. Fill the gaps</span> <span class="n">attributions</span> <span class="o">=</span> <span class="n">ndimage</span><span class="o">.</span><span class="n">binary_fill_holes</span><span class="p">(</span><span class="n">attributions</span><span class="p">)</span> <span class="c1"># 3. Compute connected components</span> <span class="n">connected_components</span><span class="p">,</span> <span class="n">num_comp</span> <span class="o">=</span> <span class="n">ndimage</span><span class="o">.</span><span class="n">label</span><span class="p">(</span> <span class="n">attributions</span><span class="p">,</span> <span class="n">structure</span><span class="o">=</span><span class="n">connected_component_structure</span> <span class="p">)</span> <span class="c1"># 4. 
Sum up the attributions for each component</span> <span class="n">total</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">attributions</span><span class="p">[</span><span class="n">connected_components</span> <span class="o">&gt;</span> <span class="mi">0</span><span class="p">])</span> <span class="n">component_sums</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">comp</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_comp</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">connected_components</span> <span class="o">==</span> <span class="n">comp</span> <span class="n">component_sum</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">attributions</span><span class="p">[</span><span class="n">mask</span><span class="p">])</span> <span class="n">component_sums</span><span class="o">.</span><span class="n">append</span><span class="p">((</span><span class="n">component_sum</span><span class="p">,</span> <span class="n">mask</span><span class="p">))</span> <span class="c1"># 5. 
Compute the percentage of top components to keep</span> <span class="n">sorted_sums_and_masks</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">component_sums</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">reverse</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">sorted_sums</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="o">*</span><span class="n">sorted_sums_and_masks</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span> <span class="n">cumulative_sorted_sums</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">cumsum</span><span class="p">(</span><span class="n">sorted_sums</span><span class="p">)</span> <span class="n">cutoff_threshold</span> <span class="o">=</span> <span class="n">percentage</span> <span class="o">*</span> <span class="n">total</span> <span class="o">/</span> <span class="mi">100</span> <span class="n">cutoff_idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">cumulative_sorted_sums</span> <span class="o">&gt;=</span> <span class="n">cutoff_threshold</span><span class="p">)[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="k">if</span> <span class="n">cutoff_idx</span> <span class="o">&gt;</span> <span class="mi">2</span><span class="p">:</span> <span class="n">cutoff_idx</span> <span class="o">=</span> <span class="mi">2</span> <span class="c1"># 6. 
Set the values for the kept components</span> <span class="n">border_mask</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">attributions</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">cutoff_idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">):</span> <span class="n">border_mask</span><span class="p">[</span><span class="n">sorted_sums_and_masks</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">]]</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># 7. Make the mask hollow and show only the border</span> <span class="n">eroded_mask</span> <span class="o">=</span> <span class="n">ndimage</span><span class="o">.</span><span class="n">binary_erosion</span><span class="p">(</span><span class="n">border_mask</span><span class="p">,</span> <span class="n">iterations</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="n">border_mask</span><span class="p">[</span><span class="n">eroded_mask</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span> <span class="c1"># 8. 
Return the outlined mask</span> <span class="k">return</span> <span class="n">border_mask</span> <span class="k">def</span><span class="w"> </span><span class="nf">process_grads</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">image</span><span class="p">,</span> <span class="n">attributions</span><span class="p">,</span> <span class="n">polarity</span><span class="o">=</span><span class="s2">&quot;positive&quot;</span><span class="p">,</span> <span class="n">clip_above_percentile</span><span class="o">=</span><span class="mf">99.9</span><span class="p">,</span> <span class="n">clip_below_percentile</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">morphological_cleanup</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">structure</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">)),</span> <span class="n">outlines</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">outlines_component_percentage</span><span class="o">=</span><span class="mi">90</span><span class="p">,</span> <span class="n">overlay</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">):</span> <span class="k">if</span> <span class="n">polarity</span> <span class="ow">not</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">&quot;positive&quot;</span><span class="p">,</span> <span class="s2">&quot;negative&quot;</span><span class="p">]:</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span> <span class="sa">f</span><span class="s2">&quot;&quot;&quot; Allowed polarity values: &#39;positive&#39; or &#39;negative&#39;</span> <span 
class="s2"> but provided </span><span class="si">{</span><span class="n">polarity</span><span class="si">}</span><span class="s2">&quot;&quot;&quot;</span> <span class="p">)</span> <span class="k">if</span> <span class="n">clip_above_percentile</span> <span class="o">&lt;</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">clip_above_percentile</span> <span class="o">&gt;</span> <span class="mi">100</span><span class="p">:</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;clip_above_percentile must be in [0, 100]&quot;</span><span class="p">)</span> <span class="k">if</span> <span class="n">clip_below_percentile</span> <span class="o">&lt;</span> <span class="mi">0</span> <span class="ow">or</span> <span class="n">clip_below_percentile</span> <span class="o">&gt;</span> <span class="mi">100</span><span class="p">:</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;clip_below_percentile must be in [0, 100]&quot;</span><span class="p">)</span> <span class="c1"># 1. 
Apply polarity</span> <span class="k">if</span> <span class="n">polarity</span> <span class="o">==</span> <span class="s2">&quot;positive&quot;</span><span class="p">:</span> <span class="n">attributions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">apply_polarity</span><span class="p">(</span><span class="n">attributions</span><span class="p">,</span> <span class="n">polarity</span><span class="o">=</span><span class="n">polarity</span><span class="p">)</span> <span class="n">channel</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">positive_channel</span> <span class="k">else</span><span class="p">:</span> <span class="n">attributions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">apply_polarity</span><span class="p">(</span><span class="n">attributions</span><span class="p">,</span> <span class="n">polarity</span><span class="o">=</span><span class="n">polarity</span><span class="p">)</span> <span class="n">attributions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">abs</span><span class="p">(</span><span class="n">attributions</span><span class="p">)</span> <span class="n">channel</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">negative_channel</span> <span class="c1"># 2. Take average over the channels</span> <span class="n">attributions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">average</span><span class="p">(</span><span class="n">attributions</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> <span class="c1"># 3. 
Apply linear transformation to the attributions</span> <span class="n">attributions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">apply_linear_transformation</span><span class="p">(</span> <span class="n">attributions</span><span class="p">,</span> <span class="n">clip_above_percentile</span><span class="o">=</span><span class="n">clip_above_percentile</span><span class="p">,</span> <span class="n">clip_below_percentile</span><span class="o">=</span><span class="n">clip_below_percentile</span><span class="p">,</span> <span class="n">lower_end</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># 4. Cleanup</span> <span class="k">if</span> <span class="n">morphological_cleanup</span><span class="p">:</span> <span class="n">attributions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">morphological_cleanup_fn</span><span class="p">(</span> <span class="n">attributions</span><span class="p">,</span> <span class="n">structure</span><span class="o">=</span><span class="n">structure</span> <span class="p">)</span> <span class="c1"># 5. Draw the outlines</span> <span class="k">if</span> <span class="n">outlines</span><span class="p">:</span> <span class="n">attributions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">draw_outlines</span><span class="p">(</span> <span class="n">attributions</span><span class="p">,</span> <span class="n">percentage</span><span class="o">=</span><span class="n">outlines_component_percentage</span> <span class="p">)</span> <span class="c1"># 6. 
Expand the channel axis and convert to RGB</span> <span class="n">attributions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">attributions</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="o">*</span> <span class="n">channel</span> <span class="c1"># 7. Superimpose on the original image</span> <span class="k">if</span> <span class="n">overlay</span><span class="p">:</span> <span class="n">attributions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">((</span><span class="n">attributions</span> <span class="o">*</span> <span class="mf">0.8</span> <span class="o">+</span> <span class="n">image</span><span class="p">),</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">255</span><span class="p">)</span> <span class="k">return</span> <span class="n">attributions</span> <span class="k">def</span><span class="w"> </span><span class="nf">visualize</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">image</span><span class="p">,</span> <span class="n">gradients</span><span class="p">,</span> <span class="n">integrated_gradients</span><span class="p">,</span> <span class="n">polarity</span><span class="o">=</span><span class="s2">&quot;positive&quot;</span><span class="p">,</span> <span class="n">clip_above_percentile</span><span class="o">=</span><span class="mf">99.9</span><span class="p">,</span> <span class="n">clip_below_percentile</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">morphological_cleanup</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">structure</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">ones</span><span 
class="p">((</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">)),</span> <span class="n">outlines</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">outlines_component_percentage</span><span class="o">=</span><span class="mi">90</span><span class="p">,</span> <span class="n">overlay</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">15</span><span class="p">,</span> <span class="mi">8</span><span class="p">),</span> <span class="p">):</span> <span class="c1"># 1. Make two copies of the original image</span> <span class="n">img1</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="n">img2</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="c1"># 2. 
Process the normal gradients</span> <span class="n">grads_attr</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">process_grads</span><span class="p">(</span> <span class="n">image</span><span class="o">=</span><span class="n">img1</span><span class="p">,</span> <span class="n">attributions</span><span class="o">=</span><span class="n">gradients</span><span class="p">,</span> <span class="n">polarity</span><span class="o">=</span><span class="n">polarity</span><span class="p">,</span> <span class="n">clip_above_percentile</span><span class="o">=</span><span class="n">clip_above_percentile</span><span class="p">,</span> <span class="n">clip_below_percentile</span><span class="o">=</span><span class="n">clip_below_percentile</span><span class="p">,</span> <span class="n">morphological_cleanup</span><span class="o">=</span><span class="n">morphological_cleanup</span><span class="p">,</span> <span class="n">structure</span><span class="o">=</span><span class="n">structure</span><span class="p">,</span> <span class="n">outlines</span><span class="o">=</span><span class="n">outlines</span><span class="p">,</span> <span class="n">outlines_component_percentage</span><span class="o">=</span><span class="n">outlines_component_percentage</span><span class="p">,</span> <span class="n">overlay</span><span class="o">=</span><span class="n">overlay</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># 3. 
Process the integrated gradients</span> <span class="n">igrads_attr</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">process_grads</span><span class="p">(</span> <span class="n">image</span><span class="o">=</span><span class="n">img2</span><span class="p">,</span> <span class="n">attributions</span><span class="o">=</span><span class="n">integrated_gradients</span><span class="p">,</span> <span class="n">polarity</span><span class="o">=</span><span class="n">polarity</span><span class="p">,</span> <span class="n">clip_above_percentile</span><span class="o">=</span><span class="n">clip_above_percentile</span><span class="p">,</span> <span class="n">clip_below_percentile</span><span class="o">=</span><span class="n">clip_below_percentile</span><span class="p">,</span> <span class="n">morphological_cleanup</span><span class="o">=</span><span class="n">morphological_cleanup</span><span class="p">,</span> <span class="n">structure</span><span class="o">=</span><span class="n">structure</span><span class="p">,</span> <span class="n">outlines</span><span class="o">=</span><span class="n">outlines</span><span class="p">,</span> <span class="n">outlines_component_percentage</span><span class="o">=</span><span class="n">outlines_component_percentage</span><span class="p">,</span> <span class="n">overlay</span><span class="o">=</span><span class="n">overlay</span><span class="p">,</span> <span class="p">)</span> <span class="n">_</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="n">figsize</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span 
class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">grads_attr</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">))</span> <span class="n">ax</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">igrads_attr</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">))</span> <span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Input&quot;</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Normal gradients&quot;</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Integrated gradients&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> </code></pre></div> <hr /> <h2 id="lets-testdrive-it">Let's test-drive it</h2> <div class="codehilite"><pre><span></span><code><span class="c1"># 1. 
Convert the image to a NumPy array</span> <span class="n">img</span> <span class="o">=</span> <span class="n">get_img_array</span><span class="p">(</span><span class="n">img_path</span><span class="p">)</span> <span class="c1"># 2. Keep a copy of the original image</span> <span class="n">orig_img</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span> <span class="c1"># 3. Preprocess the image</span> <span class="n">img_processed</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">xception</span><span class="o">.</span><span class="n">preprocess_input</span><span class="p">(</span><span class="n">img</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="c1"># 4. 
Get model predictions</span> <span class="n">preds</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">img_processed</span><span class="p">)</span> <span class="n">top_pred_idx</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">preds</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Predicted:&quot;</span><span class="p">,</span> <span class="n">top_pred_idx</span><span class="p">,</span> <span class="n">xception</span><span class="o">.</span><span class="n">decode_predictions</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">top</span><span class="o">=</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">])</span> <span class="c1"># 5. Get the gradients of the predicted class score with respect to the input image</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">get_gradients</span><span class="p">(</span><span class="n">img_processed</span><span class="p">,</span> <span class="n">top_pred_idx</span><span class="o">=</span><span class="n">top_pred_idx</span><span class="p">)</span> <span class="c1"># 6. 
Get the integrated gradients</span> <span class="n">igrads</span> <span class="o">=</span> <span class="n">random_baseline_integrated_gradients</span><span class="p">(</span> <span class="n">np</span><span class="o">.</span><span class="n">copy</span><span class="p">(</span><span class="n">orig_img</span><span class="p">),</span> <span class="n">top_pred_idx</span><span class="o">=</span><span class="n">top_pred_idx</span><span class="p">,</span> <span class="n">num_steps</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">num_runs</span><span class="o">=</span><span class="mi">2</span> <span class="p">)</span> <span class="c1"># 7. Process the gradients and plot</span> <span class="n">vis</span> <span class="o">=</span> <span class="n">GradVisualizer</span><span class="p">()</span> <span class="n">vis</span><span class="o">.</span><span class="n">visualize</span><span class="p">(</span> <span class="n">image</span><span class="o">=</span><span class="n">orig_img</span><span class="p">,</span> <span class="n">gradients</span><span class="o">=</span><span class="n">grads</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">integrated_gradients</span><span class="o">=</span><span class="n">igrads</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">clip_above_percentile</span><span class="o">=</span><span class="mi">99</span><span class="p">,</span> <span class="n">clip_below_percentile</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="p">)</span> <span class="n">vis</span><span class="o">.</span><span class="n">visualize</span><span class="p">(</span> <span class="n">image</span><span class="o">=</span><span class="n">orig_img</span><span class="p">,</span> <span class="n">gradients</span><span 
class="o">=</span><span class="n">grads</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">integrated_gradients</span><span class="o">=</span><span class="n">igrads</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">clip_above_percentile</span><span class="o">=</span><span class="mi">95</span><span class="p">,</span> <span class="n">clip_below_percentile</span><span class="o">=</span><span class="mi">28</span><span class="p">,</span> <span class="n">morphological_cleanup</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">outlines</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 1/1 ━━━━━━━━━━━━━━━━━━━━ 5s 5s/step WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1699486705.534012 86541 device_compiler.h:187] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. 
Predicted: tf.Tensor(386, shape=(), dtype=int64) [(&#39;n02504458&#39;, &#39;African_elephant&#39;, 0.8871446)] </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/integrated_gradients/integrated_gradients_9_3.png" /></p> <p><img alt="png" src="/img/examples/vision/integrated_gradients/integrated_gradients_9_4.png" /></p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#model-interpretability-with-integrated-gradients'>Model interpretability with Integrated Gradients</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#integrated-gradients'>Integrated Gradients</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#integrated-gradients-algorithm'>Integrated Gradients algorithm</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#helper-class-for-visualizing-gradients-and-integrated-gradients'>Helper class for visualizing gradients and integrated gradients</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#lets-testdrive-it'>Let's test-drive it</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>
