Involutional neural networks

<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/involution/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Involutional neural networks"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Involutional neural networks"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Involutional neural networks</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag 
Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a class="nav-sublink2" href="/examples/vision/mobilevit/">A mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a 
class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2 active" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using Composable Fully-Convolutional Networks</a> <a class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with 
RetinaNet</a> <a class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved Robustness</a> <a class="nav-sublink2" href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" 
href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/bit/">Image Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" 
href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a class="nav-sublink2" href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a 
class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); 
</script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/vision/'>Computer Vision</a> / Involutional neural networks </div> <div class='k-content'> <h1 
id="involutional-neural-networks">Involutional neural networks</h1> <p><strong>Author:</strong> <a href="https://twitter.com/ariG23498">Aritra Roy Gosthipaty</a><br> <strong>Date created:</strong> 2021/07/25<br> <strong>Last modified:</strong> 2021/07/25<br> <strong>Description:</strong> Deep dive into location-specific and channel-agnostic "involution" kernels.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/involution.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/involution.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>Convolution has been the basis of most modern neural networks for computer vision. A convolution kernel is spatial-agnostic and channel-specific. Because of this, it isn't able to adapt to different visual patterns with respect to different spatial locations. Along with location-related problems, the receptive field of convolution makes it challenging to capture long-range spatial interactions.</p> <p>To address the above issues, Li et al. rethink the properties of convolution in <a href="https://arxiv.org/abs/2103.06255">Involution: Inverting the Inherence of Convolution for Visual Recognition</a>. The authors propose the "involution kernel", which is location-specific and channel-agnostic. 
Due to the location-specific nature of the operation, the authors say that self-attention falls under the design paradigm of involution.</p> <p>This example describes the involution kernel, compares two image classification models, one with convolution and the other with involution, and also tries drawing a parallel with the self-attention layer.</p> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span><span class="w"> </span><span class="nn">os</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">&quot;KERAS_BACKEND&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="s2">&quot;tensorflow&quot;</span> <span class="kn">import</span><span class="w"> </span><span class="nn">tensorflow</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">tf</span> <span class="kn">import</span><span class="w"> </span><span class="nn">keras</span> <span class="kn">import</span><span class="w"> </span><span class="nn">matplotlib.pyplot</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">plt</span> <span class="c1"># Set seed for reproducibility.</span> <span class="n">tf</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">set_seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="convolution">Convolution</h2> <p>Convolution remains the mainstay of deep neural networks for computer vision. To understand Involution, it is necessary to talk about the convolution operation.</p> <p><img alt="Imgur" src="https://i.imgur.com/MSKLsm5.png" /></p> <p>Consider an input tensor <strong>X</strong> with dimensions <strong>H</strong>, <strong>W</strong> and <strong>C_in</strong>. 
We take a collection of <strong>C_out</strong> convolution kernels, each of shape <strong>K</strong>, <strong>K</strong>, <strong>C_in</strong>. With the multiply-add operation between the input tensor and the kernels, we obtain an output tensor <strong>Y</strong> with dimensions <strong>H</strong>, <strong>W</strong>, <strong>C_out</strong>.</p> <p>In the diagram above, <code>C_out=3</code>, so the output tensor has shape H, W, 3. Note that the convolution kernel does not depend on the spatial position in the input tensor, which makes it <strong>location-agnostic</strong>. On the other hand, each channel in the output tensor is produced by its own convolution filter, which makes it <strong>channel-specific</strong>.</p> <hr /> <h2 id="involution">Involution</h2> <p>The idea is to have an operation that is both <strong>location-specific</strong> and <strong>channel-agnostic</strong>. Implementing these properties poses a challenge: with a fixed number of involution kernels (one for each spatial position), we will <strong>not</strong> be able to process variable-resolution input tensors.</p> <p>To solve this problem, the authors <em>generate</em> each kernel conditioned on specific spatial positions. With this method, we should be able to process variable-resolution input tensors with ease. 
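</p> <p>As a rough illustration of why generated kernels are not tied to a resolution, here is a minimal NumPy sketch. It is not the layer defined in this example: it assumes a single group, stride 1, an odd kernel size, and one random linear map (<code>w_gen</code>) standing in for the kernel-generation network. The helper names <code>generate_kernels</code> and <code>involve</code> are hypothetical.</p>

```python
import numpy as np


def generate_kernels(x, w_gen):
    """Map each pixel's C channels to a flattened K*K kernel (acts like a 1x1 conv)."""
    # x has shape (H, W, C) and w_gen has shape (C, K*K); the result is (H, W, K*K).
    return x @ w_gen


def involve(x, w_gen, k):
    """Apply each pixel's generated kernel to its K x K neighbourhood.

    Assumes k is odd. The same kernel is shared by every channel at a pixel.
    """
    h, w, c = x.shape
    pad = k // 2
    # Zero-pad height and width so the output keeps the input resolution.
    padded = np.pad(x, ((pad, pad), (pad, pad), (0, 0)))
    kernels = generate_kernels(x, w_gen)
    out = np.zeros_like(x)
    for i in range(h):
        for j in range(w):
            # Neighbourhood around (i, j), flattened to (K*K, C).
            patch = padded[i:i + k, j:j + k, :].reshape(k * k, c)
            # Weighted sum over the K*K positions, one result per channel.
            out[i, j] = kernels[i, j] @ patch
    return out


rng = np.random.default_rng(0)
k, c = 3, 4
# The only weights in this sketch; their shape contains no H or W.
w_gen = rng.normal(size=(c, k * k))

# The same weights handle two different resolutions.
small = involve(rng.normal(size=(8, 8, c)), w_gen, k)
large = involve(rng.normal(size=(16, 12, c)), w_gen, k)
print(small.shape, large.shape)  # prints "(8, 8, 4) (16, 12, 4)"
```

<p>Because <code>w_gen</code> has shape <code>(C, K*K)</code>, nothing in the sketch depends on the input height or width, whereas a fixed bank of per-position kernels would be tied to one resolution.</p> <p>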
The diagram below provides an intuition on this kernel generation method.</p> <p><img alt="Imgur" src="https://i.imgur.com/jtrGGQg.png" /></p> <div class="codehilite"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">Involution</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">channel</span><span class="p">,</span> <span class="n">group_number</span><span class="p">,</span> <span class="n">kernel_size</span><span class="p">,</span> <span class="n">stride</span><span class="p">,</span> <span class="n">reduction_ratio</span><span class="p">,</span> <span class="n">name</span> <span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">)</span> <span class="c1"># Initialize the parameters.</span> <span class="bp">self</span><span class="o">.</span><span class="n">channel</span> <span class="o">=</span> <span class="n">channel</span> <span class="bp">self</span><span class="o">.</span><span class="n">group_number</span> <span class="o">=</span> <span class="n">group_number</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span> <span class="o">=</span> <span class="n">kernel_size</span> <span class="bp">self</span><span class="o">.</span><span class="n">stride</span> <span class="o">=</span> <span class="n">stride</span> <span class="bp">self</span><span class="o">.</span><span class="n">reduction_ratio</span> <span class="o">=</span> <span class="n">reduction_ratio</span> <span 
class="k">def</span><span class="w"> </span><span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="c1"># Get the shape of the input.</span> <span class="p">(</span><span class="n">_</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="n">num_channels</span><span class="p">)</span> <span class="o">=</span> <span class="n">input_shape</span> <span class="c1"># Scale the height and width with respect to the strides.</span> <span class="n">height</span> <span class="o">=</span> <span class="n">height</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">stride</span> <span class="n">width</span> <span class="o">=</span> <span class="n">width</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">stride</span> <span class="c1"># Define a layer that average pools the input tensor</span> <span class="c1"># if stride is more than 1.</span> <span class="bp">self</span><span class="o">.</span><span class="n">stride_layer</span> <span class="o">=</span> <span class="p">(</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">AveragePooling2D</span><span class="p">(</span> <span class="n">pool_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">stride</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">stride</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">&quot;same&quot;</span> <span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span 
class="n">stride</span> <span class="o">&gt;</span> <span class="mi">1</span> <span class="k">else</span> <span class="n">tf</span><span class="o">.</span><span class="n">identity</span> <span class="p">)</span> <span class="c1"># Define the kernel generation layer.</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel_gen</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span> <span class="n">filters</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">channel</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">reduction_ratio</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">1</span> <span class="p">),</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">(),</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span> <span class="n">filters</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">group_number</span><span class="p">,</span> <span class="n">kernel_size</span><span 
class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="p">),</span> <span class="p">]</span> <span class="p">)</span> <span class="c1"># Define reshape layers</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel_reshape</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Reshape</span><span class="p">(</span> <span class="n">target_shape</span><span class="o">=</span><span class="p">(</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">group_number</span><span class="p">,</span> <span class="p">)</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_patches_reshape</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Reshape</span><span class="p">(</span> <span class="n">target_shape</span><span class="o">=</span><span class="p">(</span> <span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span><span class="p">,</span> <span class="n">num_channels</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">group_number</span><span class="p">,</span> <span class="bp">self</span><span 
class="o">.</span><span class="n">group_number</span><span class="p">,</span> <span class="p">)</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_reshape</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Reshape</span><span class="p">(</span> <span class="n">target_shape</span><span class="o">=</span><span class="p">(</span><span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="n">num_channels</span><span class="p">)</span> <span class="p">)</span> <span class="k">def</span><span class="w"> </span><span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="c1"># Generate the kernel with respect to the input tensor.</span> <span class="c1"># B, H, W, K*K*G</span> <span class="n">kernel_input</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">stride_layer</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">kernel</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel_gen</span><span class="p">(</span><span class="n">kernel_input</span><span class="p">)</span> <span class="c1"># Reshape the kernel.</span> <span class="c1"># B, H, W, K*K, 1, G</span> <span class="n">kernel</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel_reshape</span><span class="p">(</span><span class="n">kernel</span><span class="p">)</span> <span class="c1"># Extract input patches.</span> <span class="c1"># B, H, W, K*K*C</span> <span class="n">input_patches</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span 
class="o">.</span><span class="n">extract_patches</span><span class="p">(</span> <span class="n">images</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">sizes</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">kernel_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">strides</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">stride</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">stride</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">rates</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">padding</span><span class="o">=</span><span class="s2">&quot;SAME&quot;</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Reshape the input patches to align with later operations.</span> <span class="c1"># B, H, W, K*K, C//G, G</span> <span class="n">input_patches</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">input_patches_reshape</span><span class="p">(</span><span class="n">input_patches</span><span class="p">)</span> <span class="c1"># Compute the multiply-add operation of kernels and patches.</span> <span class="c1"># B, H, W, K*K, C//G, G</span> <span class="n">output</span> <span class="o">=</span> <span class="n">tf</span><span 
class="o">.</span><span class="n">multiply</span><span class="p">(</span><span class="n">kernel</span><span class="p">,</span> <span class="n">input_patches</span><span class="p">)</span> <span class="c1"># B, H, W, C//G, G</span> <span class="n">output</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span> <span class="c1"># Reshape the output tensor.</span> <span class="c1"># B, H, W, C</span> <span class="n">output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_reshape</span><span class="p">(</span><span class="n">output</span><span class="p">)</span> <span class="c1"># Return the output tensor and the kernel.</span> <span class="k">return</span> <span class="n">output</span><span class="p">,</span> <span class="n">kernel</span> </code></pre></div> <hr /> <h2 id="testing-the-involution-layer">Testing the Involution layer</h2> <div class="codehilite"><pre><span></span><code><span class="c1"># Define the input tensor.</span> <span class="n">input_tensor</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">normal</span><span class="p">((</span><span class="mi">32</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="c1"># Compute involution with stride 1.</span> <span class="n">output_tensor</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">Involution</span><span class="p">(</span> <span class="n">channel</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span 
class="n">group_number</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">reduction_ratio</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;inv_1&quot;</span> <span class="p">)(</span><span class="n">input_tensor</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;with stride 1 output shape: </span><span class="si">{</span><span class="n">output_tensor</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="c1"># Compute involution with stride 2.</span> <span class="n">output_tensor</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">Involution</span><span class="p">(</span> <span class="n">channel</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">group_number</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">reduction_ratio</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;inv_2&quot;</span> <span class="p">)(</span><span class="n">input_tensor</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;with stride 2 output 
shape: </span><span class="si">{</span><span class="n">output_tensor</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="c1"># Compute involution with stride 1, channel 16 and reduction ratio 2.</span> <span class="n">output_tensor</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">Involution</span><span class="p">(</span> <span class="n">channel</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">group_number</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">reduction_ratio</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;inv_3&quot;</span> <span class="p">)(</span><span class="n">input_tensor</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span> <span class="s2">&quot;with channel 16 and reduction ratio 2 output shape: </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">output_tensor</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>with stride 1 output shape: (32, 256, 256, 3) with stride 2 output shape: (32, 128, 128, 3) with channel 16 and reduction ratio 2 output shape: (32, 256, 256, 3) </code></pre></div> </div> <hr /> <h2 id="image-classification">Image Classification</h2> <p>In this section, we will build an image-classifier 
model. There will be two models: one built with convolutions and the other with involutions.</p> <p>The image-classification model is heavily inspired by this <a href="https://www.tensorflow.org/tutorials/images/cnn">Convolutional Neural Network (CNN)</a> tutorial from Google.</p> <hr /> <h2 id="get-the-cifar10-dataset">Get the CIFAR10 Dataset</h2> <div class="codehilite"><pre><span></span><code><span class="c1"># Load the CIFAR10 dataset.</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;loading the CIFAR10 dataset...&quot;</span><span class="p">)</span> <span class="p">(</span> <span class="p">(</span><span class="n">train_images</span><span class="p">,</span> <span class="n">train_labels</span><span class="p">),</span> <span class="p">(</span> <span class="n">test_images</span><span class="p">,</span> <span class="n">test_labels</span><span class="p">,</span> <span class="p">),</span> <span class="p">)</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">cifar10</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span> <span class="c1"># Normalize pixel values to be between 0 and 1.</span> <span class="p">(</span><span class="n">train_images</span><span class="p">,</span> <span class="n">test_images</span><span class="p">)</span> <span class="o">=</span> <span class="p">(</span><span class="n">train_images</span> <span class="o">/</span> <span class="mf">255.0</span><span class="p">,</span> <span class="n">test_images</span> <span class="o">/</span> <span class="mf">255.0</span><span class="p">)</span> <span class="c1"># Shuffle and batch the dataset.</span> <span class="n">train_ds</span> <span class="o">=</span> <span class="p">(</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span 
class="n">from_tensor_slices</span><span class="p">((</span><span class="n">train_images</span><span class="p">,</span> <span class="n">train_labels</span><span class="p">))</span> <span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="mi">256</span><span class="p">)</span> <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="mi">256</span><span class="p">)</span> <span class="p">)</span> <span class="n">test_ds</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">test_images</span><span class="p">,</span> <span class="n">test_labels</span><span class="p">))</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="mi">256</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>loading the CIFAR10 dataset... 
</code></pre></div> </div> <hr /> <h2 id="visualise-the-data">Visualise the data</h2> <div class="codehilite"><pre><span></span><code><span class="n">class_names</span> <span class="o">=</span> <span class="p">[</span> <span class="s2">&quot;airplane&quot;</span><span class="p">,</span> <span class="s2">&quot;automobile&quot;</span><span class="p">,</span> <span class="s2">&quot;bird&quot;</span><span class="p">,</span> <span class="s2">&quot;cat&quot;</span><span class="p">,</span> <span class="s2">&quot;deer&quot;</span><span class="p">,</span> <span class="s2">&quot;dog&quot;</span><span class="p">,</span> <span class="s2">&quot;frog&quot;</span><span class="p">,</span> <span class="s2">&quot;horse&quot;</span><span class="p">,</span> <span class="s2">&quot;ship&quot;</span><span class="p">,</span> <span class="s2">&quot;truck&quot;</span><span class="p">,</span> <span class="p">]</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">25</span><span class="p">):</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">xticks</span><span class="p">([])</span> <span class="n">plt</span><span class="o">.</span><span class="n">yticks</span><span class="p">([])</span> <span class="n">plt</span><span class="o">.</span><span class="n">grid</span><span class="p">(</span><span 
class="kc">False</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">train_images</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="n">class_names</span><span class="p">[</span><span class="n">train_labels</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]])</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/involution/involution_13_0.png" /></p> <hr /> <h2 id="convolutional-neural-network">Convolutional Neural Network</h2> <div class="codehilite"><pre><span></span><code><span class="c1"># Build the conv model.</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;building the convolution model...&quot;</span><span class="p">)</span> <span class="n">conv_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">input_shape</span><span class="o">=</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s2">&quot;same&quot;</span><span class="p">),</span> <span 
class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;relu1&quot;</span><span class="p">),</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">MaxPooling2D</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)),</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s2">&quot;same&quot;</span><span class="p">),</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;relu2&quot;</span><span class="p">),</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">MaxPooling2D</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)),</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s2">&quot;same&quot;</span><span class="p">),</span> <span 
class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;relu3&quot;</span><span class="p">),</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Flatten</span><span class="p">(),</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">),</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span> <span class="p">]</span> <span class="p">)</span> <span class="c1"># Compile the model with the necessary loss function and optimizer.</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;compiling the convolution model...&quot;</span><span class="p">)</span> <span class="n">conv_model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="s2">&quot;adam&quot;</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;accuracy&quot;</span><span class="p">],</span> <span 
class="p">)</span> <span class="c1"># Train the model.</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;conv model training...&quot;</span><span class="p">)</span> <span class="n">conv_hist</span> <span class="o">=</span> <span class="n">conv_model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_ds</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">test_ds</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>building the convolution model... compiling the convolution model... conv model training... Epoch 1/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 6s 15ms/step - accuracy: 0.3068 - loss: 1.9000 - val_accuracy: 0.4861 - val_loss: 1.4593 Epoch 2/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.5153 - loss: 1.3603 - val_accuracy: 0.5741 - val_loss: 1.1913 Epoch 3/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.5949 - loss: 1.1517 - val_accuracy: 0.6095 - val_loss: 1.0965 Epoch 4/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.6414 - loss: 1.0330 - val_accuracy: 0.6260 - val_loss: 1.0635 Epoch 5/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.6690 - loss: 0.9485 - val_accuracy: 0.6622 - val_loss: 0.9833 Epoch 6/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.6951 - loss: 0.8764 - val_accuracy: 0.6783 - val_loss: 0.9413 Epoch 7/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.7122 - loss: 0.8167 - val_accuracy: 0.6856 - val_loss: 0.9134 Epoch 8/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.7299 - loss: 0.7709 - val_accuracy: 0.7001 - val_loss: 0.8792 Epoch 9/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.7467 - loss: 0.7288 - val_accuracy: 0.6992 - val_loss: 0.8821 Epoch 10/20 
196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.7591 - loss: 0.6982 - val_accuracy: 0.7235 - val_loss: 0.8237 Epoch 11/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.7725 - loss: 0.6550 - val_accuracy: 0.7115 - val_loss: 0.8521 Epoch 12/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.7808 - loss: 0.6302 - val_accuracy: 0.7051 - val_loss: 0.8823 Epoch 13/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.7860 - loss: 0.6101 - val_accuracy: 0.7122 - val_loss: 0.8635 Epoch 14/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.7998 - loss: 0.5786 - val_accuracy: 0.7214 - val_loss: 0.8348 Epoch 15/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.8117 - loss: 0.5473 - val_accuracy: 0.7139 - val_loss: 0.8835 Epoch 16/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.8168 - loss: 0.5267 - val_accuracy: 0.7155 - val_loss: 0.8840 Epoch 17/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.8266 - loss: 0.5022 - val_accuracy: 0.7239 - val_loss: 0.8576 Epoch 18/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.8374 - loss: 0.4750 - val_accuracy: 0.7262 - val_loss: 0.8756 Epoch 19/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.8452 - loss: 0.4505 - val_accuracy: 0.7235 - val_loss: 0.9049 Epoch 20/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.8531 - loss: 0.4283 - val_accuracy: 0.7304 - val_loss: 0.8962 </code></pre></div> </div> <hr /> <h2 id="involutional-neural-network">Involutional Neural Network</h2> <div class="codehilite"><pre><span></span><code><span class="c1"># Build the involution model.</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;building the involution model...&quot;</span><span class="p">)</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span 
class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="n">x</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">Involution</span><span class="p">(</span> <span class="n">channel</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">group_number</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">reduction_ratio</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;inv_1&quot;</span> <span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">MaxPooling2D</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">))(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">Involution</span><span class="p">(</span> <span class="n">channel</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">group_number</span><span class="o">=</span><span class="mi">1</span><span 
class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">reduction_ratio</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;inv_2&quot;</span> <span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">MaxPooling2D</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">))(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">Involution</span><span class="p">(</span> <span class="n">channel</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">group_number</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">stride</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">reduction_ratio</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;inv_3&quot;</span> <span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span 
class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Flatten</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">inv_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="p">[</span><span class="n">inputs</span><span class="p">],</span> <span class="n">outputs</span><span class="o">=</span><span class="p">[</span><span class="n">outputs</span><span class="p">],</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;inv_model&quot;</span><span class="p">)</span> <span class="c1"># Compile the model with the necessary loss function and optimizer.</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;compiling the involution 
model...&quot;</span><span class="p">)</span> <span class="n">inv_model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="s2">&quot;adam&quot;</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;accuracy&quot;</span><span class="p">],</span> <span class="p">)</span> <span class="c1"># Train the model.</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;inv model training...&quot;</span><span class="p">)</span> <span class="n">inv_hist</span> <span class="o">=</span> <span class="n">inv_model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_ds</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">test_ds</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>building the involution model... compiling the involution model... inv model training... 
Epoch 1/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 9s 25ms/step - accuracy: 0.1369 - loss: 2.2728 - val_accuracy: 0.2716 - val_loss: 2.1041 Epoch 2/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.2922 - loss: 1.9489 - val_accuracy: 0.3478 - val_loss: 1.8275 Epoch 3/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.3477 - loss: 1.8098 - val_accuracy: 0.3782 - val_loss: 1.7435 Epoch 4/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.3741 - loss: 1.7420 - val_accuracy: 0.3901 - val_loss: 1.6943 Epoch 5/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.3931 - loss: 1.6942 - val_accuracy: 0.4007 - val_loss: 1.6639 Epoch 6/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.4057 - loss: 1.6622 - val_accuracy: 0.4108 - val_loss: 1.6494 Epoch 7/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.4134 - loss: 1.6374 - val_accuracy: 0.4202 - val_loss: 1.6363 Epoch 8/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.4200 - loss: 1.6166 - val_accuracy: 0.4312 - val_loss: 1.6062 Epoch 9/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.4286 - loss: 1.5949 - val_accuracy: 0.4316 - val_loss: 1.6018 Epoch 10/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.4346 - loss: 1.5794 - val_accuracy: 0.4346 - val_loss: 1.5963 Epoch 11/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.4395 - loss: 1.5641 - val_accuracy: 0.4388 - val_loss: 1.5831 Epoch 12/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.4445 - loss: 1.5502 - val_accuracy: 0.4443 - val_loss: 1.5826 Epoch 13/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.4493 - loss: 1.5391 - val_accuracy: 0.4497 - val_loss: 1.5574 Epoch 14/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.4528 - loss: 1.5255 - val_accuracy: 0.4547 - val_loss: 1.5433 Epoch 15/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.4575 - loss: 1.5148 - val_accuracy: 0.4548 - val_loss: 1.5438 Epoch 16/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.4599 
- loss: 1.5072 - val_accuracy: 0.4581 - val_loss: 1.5323 Epoch 17/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.4664 - loss: 1.4957 - val_accuracy: 0.4598 - val_loss: 1.5321 Epoch 18/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.4701 - loss: 1.4863 - val_accuracy: 0.4575 - val_loss: 1.5302 Epoch 19/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.4737 - loss: 1.4790 - val_accuracy: 0.4676 - val_loss: 1.5233 Epoch 20/20 196/196 ━━━━━━━━━━━━━━━━━━━━ 1s 6ms/step - accuracy: 0.4771 - loss: 1.4740 - val_accuracy: 0.4719 - val_loss: 1.5096 </code></pre></div> </div> <hr /> <h2 id="comparisons">Comparisons</h2> <p>In this section, we look at both models and compare them on a few points.</p> <h3 id="parameters">Parameters</h3> <p>With a similar architecture, the CNN has far more parameters than the INN (Involutional Neural Network).</p> <div class="codehilite"><pre><span></span><code><span class="n">conv_model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> <span class="n">inv_model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> </code></pre></div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "sequential_3"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ conv2d_6 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: 
#00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">896</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ relu1 (<span style="color: #0087ff; text-decoration-color: #0087ff">ReLU</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ max_pooling2d (<span style="color: #0087ff; text-decoration-color: #0087ff">MaxPooling2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_7 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">18,496</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ 
relu2 (<span style="color: #0087ff; text-decoration-color: #0087ff">ReLU</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ max_pooling2d_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">MaxPooling2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_8 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">36,928</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ relu3 (<span style="color: #0087ff; text-decoration-color: #0087ff">ReLU</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: 
#00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ flatten (<span style="color: #0087ff; text-decoration-color: #0087ff">Flatten</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">4096</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dense (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">262,208</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dense_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">10</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">650</span> │ └─────────────────────────────────┴───────────────────────────┴────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">957,536</span> (3.65 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">319,178</span> (1.22 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier 
New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Optimizer params: </span><span style="color: #00af00; text-decoration-color: #00af00">638,358</span> (2.44 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "inv_model"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ input_layer_4 (<span style="color: #0087ff; text-decoration-color: #0087ff">InputLayer</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ inv_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">Involution</span>) │ [(<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>, <span style="color: #00af00; text-decoration-color: 
#00af00">3</span>), │ <span style="color: #00af00; text-decoration-color: #00af00">26</span> │ │ │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>, <span style="color: #00af00; text-decoration-color: #00af00">9</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>)] │ │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ re_lu_4 (<span style="color: #0087ff; text-decoration-color: #0087ff">ReLU</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>, <span style="color: #00af00; text-decoration-color: #00af00">32</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ max_pooling2d_2 (<span style="color: #0087ff; text-decoration-color: #0087ff">MaxPooling2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ inv_2 (<span style="color: #0087ff; text-decoration-color: #0087ff">Involution</span>) │ [(<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>, 
<span style="color: #00af00; text-decoration-color: #00af00">3</span>), │ <span style="color: #00af00; text-decoration-color: #00af00">26</span> │ │ │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>, <span style="color: #00af00; text-decoration-color: #00af00">9</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>)] │ │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ re_lu_6 (<span style="color: #0087ff; text-decoration-color: #0087ff">ReLU</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>, <span style="color: #00af00; text-decoration-color: #00af00">16</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ max_pooling2d_3 (<span style="color: #0087ff; text-decoration-color: #0087ff">MaxPooling2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ inv_3 (<span style="color: #0087ff; text-decoration-color: #0087ff">Involution</span>) │ [(<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: 
#00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>), (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, │ <span style="color: #00af00; text-decoration-color: #00af00">26</span> │ │ │ <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">9</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>)] │ │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ re_lu_8 (<span style="color: #0087ff; text-decoration-color: #0087ff">ReLU</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">8</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ flatten_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">Flatten</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">192</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dense_2 (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">12,352</span> │ 
├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dense_3 (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">10</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">650</span> │ └─────────────────────────────────┴───────────────────────────┴────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">39,230</span> (153.25 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">13,074</span> (51.07 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">6</span> (24.00 B) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Optimizer params: </span><span style="color: #00af00; text-decoration-color: #00af00">26,150</span> (102.15 KB) </pre> <h3 id="loss-and-accuracy-plots">Loss and Accuracy Plots</h3> <p>The loss and accuracy plots show that INNs are slow learners, albeit with far fewer parameters.</p> <div class="codehilite"><pre><span></span><code><span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span 
class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">&quot;Convolution Loss&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">conv_hist</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">&quot;loss&quot;</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s2">&quot;loss&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">conv_hist</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">&quot;val_loss&quot;</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s2">&quot;val_loss&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">&quot;Involution Loss&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span 
class="n">inv_hist</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">&quot;loss&quot;</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s2">&quot;loss&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">inv_hist</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">&quot;val_loss&quot;</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s2">&quot;val_loss&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">&quot;Convolution Accuracy&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">conv_hist</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">&quot;accuracy&quot;</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s2">&quot;accuracy&quot;</span><span 
class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">conv_hist</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">&quot;val_accuracy&quot;</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s2">&quot;val_accuracy&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">&quot;Involution Accuracy&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">inv_hist</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">&quot;accuracy&quot;</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s2">&quot;accuracy&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">inv_hist</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">&quot;val_accuracy&quot;</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s2">&quot;val_accuracy&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> 
</code></pre></div> <p><img alt="png" src="/img/examples/vision/involution/involution_22_0.png" /></p> <p><img alt="png" src="/img/examples/vision/involution/involution_22_1.png" /></p> <hr /> <h2 id="visualizing-involution-kernels">Visualizing Involution Kernels</h2> <p>To visualize the kernels, we take the sum of the <strong>K×K</strong> values of each involution kernel. <strong>All the representatives at different spatial locations frame the corresponding heat map.</strong></p> <p>The authors mention:</p> <p>"Our proposed involution is reminiscent of self-attention and essentially could become a generalized version of it."</p> <p>By visualizing the kernels, we can indeed obtain an attention map of the image. The learned involution kernels provide attention to individual spatial positions of the input tensor. The <strong>location-specific</strong> property makes involution a generic space of models to which self-attention belongs.</p> <div class="codehilite"><pre><span></span><code><span class="n">layer_names</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;inv_1&quot;</span><span class="p">,</span> <span class="s2">&quot;inv_2&quot;</span><span class="p">,</span> <span class="s2">&quot;inv_3&quot;</span><span class="p">]</span> <span class="n">outputs</span> <span class="o">=</span> <span class="p">[</span><span class="n">inv_model</span><span class="o">.</span><span class="n">get_layer</span><span class="p">(</span><span class="n">name</span><span class="p">)</span><span class="o">.</span><span class="n">output</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">name</span> <span class="ow">in</span> <span class="n">layer_names</span><span class="p">]</span> <span class="n">vis_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inv_model</span><span 
class="o">.</span><span class="n">input</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> <span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">nrows</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">ncols</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">30</span><span class="p">))</span> <span class="k">for</span> <span class="n">ax</span><span class="p">,</span> <span class="n">test_image</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">axes</span><span class="p">,</span> <span class="n">test_images</span><span class="p">[:</span><span class="mi">10</span><span class="p">]):</span> <span class="p">(</span><span class="n">inv1_kernel</span><span class="p">,</span> <span class="n">inv2_kernel</span><span class="p">,</span> <span class="n">inv3_kernel</span><span class="p">)</span> <span class="o">=</span> <span class="n">vis_model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">test_image</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="o">...</span><span class="p">])</span> <span class="n">inv1_kernel</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">inv1_kernel</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span 
class="mi">2</span><span class="p">,</span> <span class="o">-</span><span class="mi">3</span><span class="p">])</span> <span class="n">inv2_kernel</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">inv2_kernel</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">2</span><span class="p">,</span> <span class="o">-</span><span class="mi">3</span><span class="p">])</span> <span class="n">inv3_kernel</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">inv3_kernel</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">2</span><span class="p">,</span> <span class="o">-</span><span class="mi">3</span><span class="p">])</span> <span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">array_to_img</span><span class="p">(</span><span class="n">test_image</span><span class="p">))</span> <span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Input Image&quot;</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span 
class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">array_to_img</span><span class="p">(</span><span class="n">inv1_kernel</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]))</span> <span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Involution Kernel 1&quot;</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">array_to_img</span><span class="p">(</span><span class="n">inv2_kernel</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]))</span> <span class="n">ax</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Involution Kernel 2&quot;</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">array_to_img</span><span class="p">(</span><span class="n">inv3_kernel</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]))</span> <span 
class="n">ax</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Involution Kernel 3&quot;</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 503ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/involution/involution_24_1.png" /></p> <hr /> <h2 id="conclusions">Conclusions</h2> <p>In this example, the main focus was to build an <code>Involution</code> layer that can be easily reused. While our comparisons were based on a specific task, feel free to use the layer for different tasks and report your results.</p> <p>In my view, the key takeaway of involution is its relationship with self-attention. 
The intuition behind location-specific and channel-specific processing makes sense in a lot of tasks.</p> <p>Moving forward, one can:</p> <ul> <li>Look at <a href="https://youtu.be/pH2jZun8MoY">Yannic Kilcher's video</a> on involution for a better understanding.</li> <li>Experiment with the various hyperparameters of the involution layer.</li> <li>Build different models with the involution layer.</li> <li>Try building a different kernel generation method altogether.</li> </ul> <p>You can use the trained model hosted on <a href="https://huggingface.co/keras-io/involution">Hugging Face Hub</a> and try the demo on <a href="https://huggingface.co/spaces/keras-io/involution">Hugging Face Spaces</a>.</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#involutional-neural-networks'>Involutional neural networks</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#convolution'>Convolution</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#involution'>Involution</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#testing-the-involution-layer'>Testing the Involution layer</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#image-classification'>Image Classification</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#get-the-cifar10-dataset'>Get the CIFAR10 Dataset</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#visualise-the-data'>Visualise the data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#convolutional-neural-network'>Convolutional Neural Network</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#involutional-neural-network'>Involutional Neural Network</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#comparisons'>Comparisons</a> </div> <div class='k-outline-depth-3'> <a href='#parameters'>Parameters</a> </div> <div class='k-outline-depth-3'> <a 
href='#loss-and-accuracy-plots'>Loss and Accuracy Plots</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#visualizing-involution-kernels'>Visualizing Involution Kernels</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#conclusions'>Conclusions</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>
