Image Classification using Global Context Vision Transformer
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/image_classification_using_global_context_vision_transformer/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Image Classification using Global Context Vision Transformer"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Image Classification using Global Context Vision Transformer"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Image Classification using Global Context Vision Transformer</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) 
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a class="nav-sublink2" href="/examples/vision/mobilevit/">A 
mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2 active" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using Composable Fully-Convolutional Networks</a> <a 
class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved Robustness</a> <a class="nav-sublink2" 
href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/bit/">Image 
Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a class="nav-sublink2" 
href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = 
navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'>
<span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/vision/'>Computer Vision</a> / Image Classification using Global Context Vision Transformer </div> <div class='k-content'> <h1 id="image-classification-using-global-context-vision-transformer">Image Classification using Global Context Vision Transformer</h1> <p><strong>Author:</strong> Md Awsafur Rahman<br> <strong>Date created:</strong> 2023/10/30<br> <strong>Last modified:</strong> 2023/10/30<br> <strong>Description:</strong> Implementation and fine-tuning of Global Context Vision Transformer for image classification.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/image_classification_using_global_context_vision_transformer.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/image_classification_using_global_context_vision_transformer.py"><strong>GitHub source</strong></a></p> <h1 id="setup">Setup</h1> <div class="codehilite"><pre><span></span><code><span class="err">!</span><span class="n">pip</span> <span class="n">install</span> <span class="o">--</span><span class="n">upgrade</span> <span class="n">keras_cv</span> <span class="n">tensorflow</span>
<span class="err">!</span><span class="n">pip</span> <span class="n">install</span> <span class="o">--</span><span class="n">upgrade</span> <span class="n">keras</span>
</code></pre></div> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">keras</span>
<span class="kn">from</span> <span class="nn">keras_cv.layers</span> <span class="kn">import</span> <span class="n">DropPath</span>
<span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">ops</span>
<span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span>  <span class="c1"># only for dataloader</span>
<span class="kn">import</span> <span class="nn">tensorflow_datasets</span> <span class="k">as</span> <span class="nn">tfds</span>  <span class="c1"># for flower dataset</span>
<span class="kn">from</span> <span class="nn">skimage.data</span> <span class="kn">import</span> <span class="n">chelsea</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
</code></pre></div> <hr /> <h2 id="introduction">Introduction</h2> <p>In this notebook, we will use multi-backend Keras 3 to implement the <a href="https://arxiv.org/abs/2206.09959"><strong>GCViT: Global Context Vision Transformer</strong></a> paper, presented at ICML 2023 by A. Hatamizadeh et al. Then, we will fine-tune the model on the Flower dataset for an image classification task, leveraging the official ImageNet pre-trained weights.
A highlight of this notebook is its compatibility with multiple backends (TensorFlow, PyTorch, and JAX), showcasing the true potential of multi-backend Keras.</p> <hr /> <h2 id="motivation">Motivation</h2> <blockquote> <p><strong>Note:</strong> In this section, we'll look at the backstory of GCViT and try to understand why it was proposed.</p> </blockquote> <ul> <li>In recent years, <strong>Transformers</strong> have come to dominate <strong>Natural Language Processing (NLP)</strong> tasks, largely thanks to the <strong>self-attention</strong> mechanism, which can capture both long- and short-range information.</li> <li>Following this trend, <strong>Vision Transformer (ViT)</strong> proposed using image patches as tokens in a large architecture similar to the encoder of the original Transformer.</li> <li>Despite the historic dominance of <strong>Convolutional Neural Networks (CNNs)</strong> in computer vision, <strong>ViT-based</strong> models have shown <strong>state-of-the-art (SOTA) or competitive performance</strong> in various computer vision tasks.</li> </ul> <p><img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/vit_gif.gif" width=600></p> <ul> <li>However, the <strong>quadratic [<code>O(n^2)</code>] computational complexity</strong> of self-attention in the number of tokens <code>n</code>, together with the <strong>lack of multi-scale information</strong>, makes it difficult to treat <strong>ViT</strong> as a general-purpose architecture for computer vision tasks such as <strong>segmentation and object detection</strong>, which require <strong>dense prediction at the pixel level</strong>.</li> <li>Swin Transformer attempted to address these issues of <strong>ViT</strong> by proposing a <strong>multi-resolution/hierarchical</strong> architecture in which self-attention is computed within <strong>local windows</strong>, and cross-window connections such as <strong>window shifting</strong> model the interactions across different regions.
However, the <strong>limited receptive field of local windows</strong> cannot capture long-range information, and cross-window connection schemes such as <strong>window shifting only cover a small neighborhood</strong> around each window. Swin Transformer also lacks the <strong>inductive bias</strong> that encourages a degree of translation invariance, which is still preferable for general-purpose visual modeling, particularly for dense prediction tasks such as object detection and semantic segmentation.</li> </ul> <p><img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/swin_vs_vit.JPG" width=400> <img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/shifted_window.JPG" width=400> <img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/swin_arch.JPG" width=800></p> <ul> <li>To address the above limitations, the <strong>Global Context (GC) ViT</strong> network was proposed.</li> </ul> <hr /> <h2 id="architecture">Architecture</h2> <p>Let's have a quick <strong>overview</strong> of our key components:</p> <ol> <li><code>Stem/PatchEmbed:</code> A stem/patchify layer processes images at the beginning of the network. For this network, it creates <strong>patches/tokens</strong> and converts them into <strong>embeddings</strong>.</li> <li><code>Level:</code> The repeated building block that extracts features using different blocks.</li> <li><code>Global Token Gen./FeatureExtraction:</code> Generates <strong>global tokens/patches</strong> with <strong>Depthwise-CNN</strong>, <strong>SqueezeAndExcitation (Squeeze-Excitation)</strong>, <strong>CNN</strong> and <strong>MaxPooling</strong>. In other words, it is a feature extractor.</li> <li><code>Block:</code> The repeated module that applies attention to the features and projects them to a certain dimension. <ol> <li><code>Local-MSA:</code> Local Multi-head Self-Attention.</li> <li><code>Global-MSA:</code> Global Multi-head Self-Attention.</li> <li><code>MLP:</code> Linear layer that projects a vector to another dimension.</li> </ol> </li> <li><code>Downsample/ReduceSize:</code> Very similar to the <strong>Global Token Gen.</strong> module, except that it downsamples with a <strong>CNN</strong> instead of <strong>MaxPooling</strong> and adds <strong>Layer Normalization</strong> modules.</li> <li><code>Head:</code> The module responsible for the classification task. <ol> <li><code>Pooling:</code> Converts <code>N x 2D</code> features to <code>N x 1D</code> features.</li> <li><code>Classifier:</code> Processes the <code>N x 1D</code> features to predict the class.</li> </ol> </li> </ol> <p>I've annotated the architecture figure to make it easier to digest: <img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/arch_annot.png"></p> <h3 id="unit-blocks">Unit Blocks</h3> <blockquote> <p><strong>Note:</strong> These blocks are used to build other modules throughout the paper. Most of them are either borrowed from other work or are modified versions of earlier work.</p> </blockquote> <ol> <li> <p><code>SqueezeAndExcitation</code>: The <strong>Squeeze-Excitation (SE)</strong> module, also called a <strong>Bottleneck</strong> module, acts as a kind of <strong>channel attention</strong>. It consists of <strong>AvgPooling</strong>, <strong>Dense/FullyConnected (FC)/Linear</strong>, <strong>GELU</strong> and <strong>Sigmoid</strong> modules: it squeezes each channel to a single value with global average pooling, passes the result through the FC layers, and rescales the input channels by the resulting sigmoid weights. <img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/se_annot.png" width=400></p> </li> <li> <p><code>Fused-MBConv:</code> This is similar to the block used in <strong>EfficientNetV2</strong>. It uses <strong>Depthwise-Conv</strong>, <strong>GELU</strong>, <strong>SqueezeAndExcitation</strong> and <strong>Conv</strong> to extract features, with a residual connection. Note that no new module is declared for this one; we simply apply the corresponding modules directly.
<img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/fmb_annot.png" width=350></p> </li> <li> <p><code>ReduceSize</code>: A <strong>CNN</strong>-based <strong>downsample</strong> module. It uses the above-mentioned <code>Fused-MBConv</code> module to extract features, a <strong>Strided Conv</strong> to simultaneously reduce the spatial dimensions and increase the channel dimension of the features, and finally a <strong>LayerNormalization</strong> module to normalize the features. In the paper/figure, this module is referred to as the <strong>downsample</strong> module. It is worth mentioning that <strong>Swin Transformer</strong> uses a <code>PatchMerging</code> module instead of <code>ReduceSize</code> to reduce the spatial dimensions and increase the channel dimension, and that <code>PatchMerging</code> is built on a <strong>fully-connected/dense/linear</strong> module. According to the <strong>GCViT</strong> paper, one purpose of using <code>ReduceSize</code> is to add inductive bias through the <strong>CNN</strong> module. <img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/down_annot.png" width=300></p> </li> <li> <p><code>MLP:</code> This is our very own <strong>Multi-Layer Perceptron</strong> module.
It is a feed-forward/fully-connected/linear module that simply projects its input to an arbitrary dimension.</p> </li> </ol> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">SqueezeAndExcitation</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span>
<span class="w">    </span><span class="sd">"""Squeeze and excitation block.</span>

<span class="sd">    Args:</span>
<span class="sd">        output_dim: output features dimension, if `None` use same dim as input.</span>
<span class="sd">        expansion: expansion ratio.</span>
<span class="sd">    """</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">output_dim</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">expansion</span><span class="o">=</span><span class="mf">0.25</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">expansion</span> <span class="o">=</span> <span class="n">expansion</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">output_dim</span> <span class="o">=</span> <span class="n">output_dim</span>

    <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span>
        <span class="n">inp</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
        <span
class="bp">self</span><span class="o">.</span><span class="n">output_dim</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_dim</span> <span class="ow">or</span> <span class="n">inp</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">avg_pool</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">GlobalAvgPool2D</span><span class="p">(</span><span class="n">keepdims</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"avg_pool"</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">fc</span> <span class="o">=</span> <span class="p">[</span>
            <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="nb">int</span><span class="p">(</span><span class="n">inp</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">expansion</span><span class="p">),</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"fc_0"</span><span class="p">),</span>
            <span class="n">layers</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="s2">"gelu"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"fc_1"</span><span class="p">),</span>
            <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">output_dim</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"fc_2"</span><span class="p">),</span>
            <span class="n">layers</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="s2">"sigmoid"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"fc_3"</span><span class="p">),</span>
        <span class="p">]</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">avg_pool</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span>
        <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc</span><span class="p">:</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span> <span class="o">*</span> <span class="n">inputs</span>


<span class="k">class</span> <span class="nc">ReduceSize</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span>
<span class="w">    </span><span class="sd">"""Down-sampling block.</span>

<span class="sd">    Args:</span>
<span class="sd">        keepdims: if False spatial dim is reduced and channel dim is increased</span>
<span class="sd">    """</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">keepdims</span> <span class="o">=</span> <span class="n">keepdims</span>

    <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span>
        <span class="n">embed_dim</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
        <span class="n">dim_out</span> <span class="o">=</span> <span class="n">embed_dim</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">keepdims</span> <span class="k">else</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">embed_dim</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">pad1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">ZeroPadding2D</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"pad1"</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">pad2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">ZeroPadding2D</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"pad2"</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">conv</span> <span class="o">=</span> <span class="p">[</span>
            <span class="n">layers</span><span class="o">.</span><span class="n">DepthwiseConv2D</span><span class="p">(</span>
                <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"valid"</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"conv_0"</span>
            <span class="p">),</span>
            <span class="n">layers</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="s2">"gelu"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"conv_1"</span><span class="p">),</span>
            <span class="n">SqueezeAndExcitation</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"conv_2"</span><span class="p">),</span>
            <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span>
                <span class="n">embed_dim</span><span class="p">,</span>
                <span class="n">kernel_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
                <span class="n">strides</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
                <span class="n">padding</span><span class="o">=</span><span class="s2">"valid"</span><span class="p">,</span>
                <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
                <span class="n">name</span><span class="o">=</span><span class="s2">"conv_3"</span><span class="p">,</span>
            <span
class="p">),</span> <span class="p">]</span> <span class="bp">self</span><span class="o">.</span><span class="n">reduction</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span> <span class="n">dim_out</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"valid"</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"reduction"</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mf">1e-05</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"norm1"</span> <span class="p">)</span> <span class="c1"># eps like PyTorch</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mf">1e-05</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"norm2"</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span 
class="n">inputs</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm1</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">xr</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pad1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="p">:</span> <span class="n">xr</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">xr</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">xr</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pad2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">reduction</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> <span class="k">class</span> <span class="nc">MLP</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Multi-Layer Perceptron (MLP) block.</span> <span class="sd"> Args:</span> <span class="sd"> 
hidden_features: hidden features dimension.</span> <span class="sd"> out_features: output features dimension.</span> <span class="sd"> activation: activation function.</span> <span class="sd"> dropout: dropout rate.</span> <span class="sd"> """</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">hidden_features</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">out_features</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"gelu"</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span> <span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_features</span> <span class="o">=</span> <span class="n">hidden_features</span> <span class="bp">self</span><span class="o">.</span><span class="n">out_features</span> <span class="o">=</span> <span class="n">out_features</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">activation</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">dropout</span> <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span 
class="n">in_features</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_features</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">in_features</span> <span class="bp">self</span><span class="o">.</span><span class="n">out_features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">out_features</span> <span class="ow">or</span> <span class="bp">self</span><span class="o">.</span><span class="n">in_features</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">hidden_features</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"fc1"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"act"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">out_features</span><span 
class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"fc2"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"drop1"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"drop2"</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc1</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">act</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">fc2</span><span 
class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> </code></pre></div> <h3 id="stem">Stem</h3> <blockquote> <p><strong>Notes</strong>: In the code, this module is referred to as <strong>PatchEmbed</strong>, but in the paper it is referred to as <strong>Stem</strong>.</p> </blockquote> <p>The model first applies the <code>patch_embed</code> module. Let's try to understand this module. As we can see from the <code>call</code> method, 1. This module first <strong>pads</strong> the input. 2. It then uses a strided <strong>convolution</strong> (<code>proj</code>) to extract patches and project them into embeddings, halving the spatial dimensions. 3. Finally, it uses the <code>ReduceSize</code> module to extract features with <strong>convolutions</strong> and then down-sample by another 2x; since it is called with <code>keepdims=True</code>, the channel dimension is kept rather than doubled. 4. One important point to notice: unlike <strong>ViT</strong> or <strong>SwinTransformer</strong>, <strong>GCViT</strong> creates <strong>overlapping patches</strong>. We can see this in the code, <code>Conv2D(self.embed_dim, kernel_size=3, strides=2, name='proj')</code>: the kernel size (3) is larger than the stride (2), so neighboring patches share pixels. If we wanted <strong>non-overlapping</strong> patches, we would have used the same value for <code>kernel_size</code> and <code>stride</code>. 5.
This module reduces the spatial dimension of input by <code>4x</code>.</p> <blockquote> <p>Summary: image → padding → convolution → (feature_extract + downsample)</p> </blockquote> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">PatchEmbed</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Patch embedding block.</span> <span class="sd"> Args:</span> <span class="sd"> embed_dim: feature size dimension.</span> <span class="sd"> """</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span> <span class="o">=</span> <span class="n">embed_dim</span> <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">pad</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">ZeroPadding2D</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"pad"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span 
class="n">Conv2D</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"proj"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_down</span> <span class="o">=</span> <span class="n">ReduceSize</span><span class="p">(</span><span class="n">keepdims</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"conv_down"</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pad</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv_down</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> </code></pre></div> <h3 id="global-token-gen">Global Token Gen.</h3> <blockquote> <p><strong>Notes:</strong> It is one of the two <strong>CNN</strong> modules used to impose inductive bias.</p> </blockquote> <p>As we can see from the cell above, in the <code>level</code> we first use
<code>to_q_global/Global Token Gen./FeatureExtraction</code>. Let's try to understand how it works:</p> <ul> <li>This module is a series of <code>FeatureExtraction</code> modules. According to the paper, we need to repeat this module <code>K</code> times, where <code>K = log2(H/h)</code>, and <code>H</code> and <code>h</code> are the heights of the feature map and of the window, respectively.</li> <li><code>FeatureExtraction:</code> This layer is very similar to the <code>ReduceSize</code> module, except that it uses <strong>MaxPooling</strong> to reduce the spatial dimension, it doesn't increase the feature (channel) dimension, and it doesn't use <strong>LayerNormalization</strong>. This module is used repeatedly in the <code>Global Token Gen.</code> module to generate <strong>global tokens</strong> for <strong>global-context-attention</strong>.</li> <li>One important point to notice from the figure is that the <strong>global tokens</strong> are shared across the whole image, which means we use only <strong>one global window</strong> for <strong>all local tokens</strong> in an image. This makes the computation very efficient.</li> <li>For an input feature map with shape <code>(B, H, W, C)</code>, we'll get an output shape of <code>(B, h, w, C)</code>.
If we copy these global tokens to all <code>M</code> local windows in an image, where <code>M = (H x W)/(h x w) = num_window</code>, the output shape becomes <code>(B * M, h, w, C)</code>.</li> </ul> <blockquote> <p>Summary: This module is used to <code>resize</code> the image (feature map) to fit the window.</p> </blockquote> <p><img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/global_token_annot.png" width=800></p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">FeatureExtraction</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Feature extraction block.</span> <span class="sd"> Args:</span> <span class="sd"> keepdims: bool argument for maintaining the resolution.</span> <span class="sd"> """</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">keepdims</span> <span class="o">=</span> <span class="n">keepdims</span> <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="n">embed_dim</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="bp">self</span><span
class="o">.</span><span class="n">pad1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">ZeroPadding2D</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"pad1"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">pad2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">ZeroPadding2D</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"pad2"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span> <span class="o">=</span> <span class="p">[</span> <span class="n">layers</span><span class="o">.</span><span class="n">DepthwiseConv2D</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"conv_0"</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="s2">"gelu"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"conv_1"</span><span class="p">),</span> <span class="n">SqueezeAndExcitation</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"conv_2"</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span 
class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"conv_3"</span><span class="p">),</span> <span class="p">]</span> <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">keepdims</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">pool</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MaxPool2D</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"pool"</span><span class="p">)</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">inputs</span> <span class="n">xr</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pad1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">conv</span><span class="p">:</span> <span class="n">xr</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">xr</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> 
<span class="n">xr</span> <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">keepdims</span><span class="p">:</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pool</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">pad2</span><span class="p">(</span><span class="n">x</span><span class="p">))</span> <span class="k">return</span> <span class="n">x</span> <span class="k">class</span> <span class="nc">GlobalQueryGenerator</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Global query generator.</span> <span class="sd"> Args:</span> <span class="sd"> keepdims: list of booleans, one per FeatureExtraction layer, setting whether that layer keeps the resolution.</span> <span class="sd"> For instance, repeating log2(56/7) = 3 blocks, with input</span> <span class="sd"> window dimension 56 and output window dimension 7 at down-sampling</span> <span class="sd"> ratio 2.
Please check Fig.5 of GC ViT paper for details.</span> <span class="sd"> """</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">keepdims</span> <span class="o">=</span> <span class="n">keepdims</span> <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_q_global</span> <span class="o">=</span> <span class="p">[</span> <span class="n">FeatureExtraction</span><span class="p">(</span><span class="n">keepdims</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">"to_q_global_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">keepdims</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">keepdims</span><span class="p">)</span> <span class="p">]</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span> <span class="k">def</span> <span 
class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">inputs</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">to_q_global</span><span class="p">:</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> </code></pre></div> <h3 id="attention">Attention</h3> <blockquote> <p><strong>Notes:</strong> This is the core contribution of the paper.</p> </blockquote> <p>As we can see from the <code>call</code> method, the <code>WindowAttention</code> module applies either <strong>local</strong> or <strong>global</strong> window attention, depending on the <code>global_query</code> parameter.</p> <ol> <li>First, it converts the input features into <code>query, key, value</code> for local attention, and into <code>key, value</code> for global attention. For global attention, it takes the global query from <code>Global Token Gen.</code>. Notice that the <strong>features (embed_dim)</strong> are split among all the <strong>heads of the Transformer</strong>, so each head works on <code>C // num_heads</code> channels and the total cost stays comparable to single-head attention: <code>qkv = tf.reshape(qkv, [B_, N, self.qkv_size, self.num_heads, C // self.num_heads])</code></li> <li>Before the query, key and value are sent to attention, the <strong>global tokens</strong> go through an important step: the same global tokens (a single global window) are copied to all the local windows, which keeps this step efficient.
<code>q_global = tf.repeat(q_global, repeats=B_//B, axis=0)</code>, where <code>B_//B</code> is <code>num_windows</code>, the number of windows in an image.</li> <li>It then applies <code>local-window-self-attention</code> or <code>global-window-attention</code>, depending on the <code>global_query</code> parameter. Notice that the <strong>relative positional bias</strong> is added to the <strong>attention map</strong> (the attention scores) rather than to the <strong>patch embeddings</strong>: <code>attn = attn + relative_position_bias[tf.newaxis,]</code> <img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/lvg_msa.PNG" width=800></li> <li>Now, let's think for a bit about what is happening here, focusing on the figure below. On the left, in <strong>local attention</strong> the <strong>query is local</strong> and <strong>limited to the local window</strong> (red square border), so we don't have access to long-range information. On the right, thanks to the <strong>global query</strong>, we are <strong>no longer limited to local windows</strong> (blue square border) and have access to long-range information. <img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/lvg_arch.PNG" width=800></li> <li>In <strong>ViT</strong> we compare (attend) image tokens with image tokens, and in <strong>SwinTransformer</strong> we compare window tokens with window tokens, but in <strong>GCViT</strong> we compare image tokens with window tokens. You may now ask: how can we compare (attend) image tokens with window tokens when the image tokens have larger dimensions than the window tokens? (In the figure above, image tokens have shape <code>(1, 8, 8, 3)</code> and window tokens have shape <code>(1, 4, 4, 3)</code>.) Indeed, we can't compare them directly, so we resize the image tokens to fit the window tokens with the <code>Global Token Gen./FeatureExtraction</code> <strong>CNN</strong> module.
The following table gives a clear comparison:</li> </ol> <table> <thead> <tr> <th>Model</th> <th>Query Tokens</th> <th>Key-Value Tokens</th> <th>Attention Type</th> <th>Attention Coverage</th> </tr> </thead> <tbody> <tr> <td>ViT</td> <td>image</td> <td>image</td> <td>self-attention</td> <td>global</td> </tr> <tr> <td>SwinTransformer</td> <td>window</td> <td>window</td> <td>self-attention</td> <td>local</td> </tr> <tr> <td><strong>GCViT</strong></td> <td><strong>resized-image</strong></td> <td><strong>window</strong></td> <td><strong>image-window attention</strong></td> <td><strong>global</strong></td> </tr> </tbody> </table> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">WindowAttention</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Local window attention.</span> <span class="sd"> This implementation was proposed by</span> <span class="sd"> [Liu et al., 2021](https://arxiv.org/abs/2103.14030) in SwinTransformer.</span> <span class="sd"> Args:</span> <span class="sd"> window_size: window size.</span> <span class="sd"> num_heads: number of attention heads.</span> <span class="sd"> global_query: whether the input contains a global query.</span> <span class="sd"> qkv_bias: bool argument for query, key, value learnable bias.</span> <span class="sd"> qk_scale: optional float that overrides the default query-key scale (head_dim ** -0.5).</span> <span class="sd"> attention_dropout: attention dropout rate.</span> <span class="sd"> projection_dropout: output dropout rate.</span> <span class="sd"> """</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">global_query</span><span class="p">,</span> <span 
class="n">qkv_bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">qk_scale</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">attention_dropout</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">projection_dropout</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span> <span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="n">window_size</span> <span class="o">=</span> <span class="p">(</span><span class="n">window_size</span><span class="p">,</span> <span class="n">window_size</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span> <span class="o">=</span> <span class="n">window_size</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span> <span class="o">=</span> <span class="n">num_heads</span> <span class="bp">self</span><span class="o">.</span><span class="n">global_query</span> <span class="o">=</span> <span class="n">global_query</span> <span class="bp">self</span><span class="o">.</span><span class="n">qkv_bias</span> <span class="o">=</span> <span class="n">qkv_bias</span> <span class="bp">self</span><span class="o">.</span><span class="n">qk_scale</span> <span class="o">=</span> <span class="n">qk_scale</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention_dropout</span> <span class="o">=</span> <span class="n">attention_dropout</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection_dropout</span> <span class="o">=</span> <span 
class="n">projection_dropout</span> <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="n">embed_dim</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">head_dim</span> <span class="o">=</span> <span class="n">embed_dim</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">qk_scale</span> <span class="ow">or</span> <span class="n">head_dim</span><span class="o">**-</span><span class="mf">0.5</span> <span class="bp">self</span><span class="o">.</span><span class="n">qkv_size</span> <span class="o">=</span> <span class="mi">3</span> <span class="o">-</span> <span class="nb">int</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">global_query</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">qkv</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span> <span class="n">embed_dim</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">qkv_size</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">qkv_bias</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"qkv"</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span 
class="n">relative_position_bias_table</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">name</span><span class="o">=</span><span class="s2">"relative_position_bias_table"</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span> <span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="p">],</span> <span class="n">initializer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">TruncatedNormal</span><span class="p">(</span><span class="n">stddev</span><span class="o">=</span><span class="mf">0.02</span><span class="p">),</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn_drop</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span 
class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">attention_dropout</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"attn_drop"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"proj"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj_drop</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">projection_dropout</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"proj_drop"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="s2">"softmax"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"softmax"</span><span class="p">)</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span> <span class="k">def</span> <span class="nf">get_relative_position_index</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="n">coords_h</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="bp">self</span><span 
class="o">.</span><span class="n">window_size</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="n">coords_w</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span> <span class="n">coords</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">meshgrid</span><span class="p">(</span><span class="n">coords_h</span><span class="p">,</span> <span class="n">coords_w</span><span class="p">,</span> <span class="n">indexing</span><span class="o">=</span><span class="s2">"ij"</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">coords_flatten</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">coords</span><span class="p">,</span> <span class="p">[</span><span class="mi">2</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">])</span> <span class="n">relative_coords</span> <span class="o">=</span> <span class="n">coords_flatten</span><span class="p">[:,</span> <span class="p">:,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">-</span> <span class="n">coords_flatten</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="n">relative_coords</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">relative_coords</span><span 
class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">])</span> <span class="n">relative_coords_xx</span> <span class="o">=</span> <span class="n">relative_coords</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span> <span class="n">relative_coords_yy</span> <span class="o">=</span> <span class="n">relative_coords</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span> <span class="n">relative_coords_xx</span> <span class="o">=</span> <span class="n">relative_coords_xx</span> <span class="o">*</span> <span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="n">relative_position_index</span> <span class="o">=</span> <span class="n">relative_coords_xx</span> <span class="o">+</span> <span class="n">relative_coords_yy</span> <span class="k">return</span> <span class="n">relative_position_index</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span 
class="n">inputs</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">global_query</span><span class="p">:</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">q_global</span> <span class="o">=</span> <span class="n">inputs</span> <span class="n">B</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">q_global</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="c1"># B, N, C</span> <span class="k">else</span><span class="p">:</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="n">B_</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="c1"># B*num_window, num_tokens, channels</span> <span class="n">qkv</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">qkv</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span> <span class="n">qkv</span><span class="p">,</span> <span class="p">[</span><span class="n">B_</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">qkv_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span 
class="n">num_heads</span><span class="p">,</span> <span class="n">C</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span><span class="p">]</span> <span class="p">)</span> <span class="n">qkv</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="p">[</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">])</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">global_query</span><span class="p">:</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">split</span><span class="p">(</span> <span class="n">qkv</span><span class="p">,</span> <span class="n">indices_or_sections</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span> <span class="p">)</span> <span class="c1"># for unknown shape, num=None will throw an error</span> <span class="n">q_global</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">repeat</span><span class="p">(</span> <span class="n">q_global</span><span class="p">,</span> <span class="n">repeats</span><span class="o">=</span><span class="n">B_</span> <span class="o">//</span> <span class="n">B</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span> <span class="p">)</span> <span class="c1"># num_windows = B_//B => q_global is the same for all windows in an image</span> <span class="n">q</span> <span 
class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">q_global</span><span class="p">,</span> <span class="p">[</span><span class="n">B_</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">C</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span><span class="p">])</span> <span class="n">q</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="k">else</span><span class="p">:</span> <span class="n">q</span><span class="p">,</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">qkv</span><span class="p">,</span> <span class="n">indices_or_sections</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">q</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">q</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">k</span> <span class="o">=</span> <span 
class="n">ops</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">v</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">q</span> <span class="o">=</span> <span class="n">q</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">scale</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">q</span> <span class="o">@</span> <span class="n">ops</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span> <span class="n">relative_position_bias</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">take</span><span class="p">(</span> <span class="bp">self</span><span class="o">.</span><span class="n">relative_position_bias_table</span><span class="p">,</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">get_relative_position_index</span><span class="p">(),</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]),</span> <span class="p">)</span> <span class="n">relative_position_bias</span> <span class="o">=</span> 
<span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span> <span class="n">relative_position_bias</span><span class="p">,</span> <span class="p">[</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="p">],</span> <span class="p">)</span> <span class="n">relative_position_bias</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">relative_position_bias</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="p">[</span><span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">attn</span> <span class="o">+</span> <span class="n">relative_position_bias</span><span class="p">[</span><span class="kc">None</span><span class="p">,]</span> <span class="n">attn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span> <span class="n">attn</span> <span class="o">=</span> <span 
class="bp">self</span><span class="o">.</span><span class="n">attn_drop</span><span class="p">(</span><span class="n">attn</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">transpose</span><span class="p">((</span><span class="n">attn</span> <span class="o">@</span> <span class="n">v</span><span class="p">),</span> <span class="n">axes</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="p">[</span><span class="n">B_</span><span class="p">,</span> <span class="n">N</span><span class="p">,</span> <span class="n">C</span><span class="p">])</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj_drop</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">proj</span><span class="p">(</span><span class="n">x</span><span class="p">))</span> <span class="k">return</span> <span class="n">x</span> </code></pre></div> <h3 id="block">Block</h3> <blockquote> <p><strong>Notes:</strong> This module doesn't contain any convolutional layers.</p> </blockquote> <p>The second module used in the <code>level</code> is <code>block</code>. Let's try to understand how it works. As we can see from the <code>call</code> method:</p> <ol> <li>The <code>Block</code> module takes either only the feature maps (for local attention) or the feature maps plus a global query (for global attention).</li> <li>Before sending the feature maps to attention, it converts <strong>batch feature maps</strong> into <strong>batch windows</strong>, since we'll be applying <strong>Window Attention</strong>.</li> <li>It then sends the <strong>batch windows</strong> to attention.</li> <li>After attention has been applied, it reverts the <strong>batch windows</strong> back to <strong>batch feature maps</strong>.</li> <li>Before adding each branch output back to its input, it applies <strong>Stochastic Depth</strong> regularization on the residual connection. Before applying <strong>Stochastic Depth</strong>, it also rescales the branch output with trainable parameters (<code>gamma1</code> and <code>gamma2</code>, which are created as trainable weights when <code>layer_scale</code> is not <code>None</code>). Note that this <strong>Stochastic Depth</strong> block is not shown in the paper's figure.</li> </ol> <p><img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/block2.JPG" width=400></p> <h3 id="window">Window</h3> <p>In the <code>block</code> module, we create <strong>windows</strong> before applying attention and merge them back afterwards. Let's see how the windows are created:</p> <ul> <li>The following module converts feature maps <code>(B, H, W, C)</code> into stacked windows <code>(B x H/h x W/w, h, w, C)</code> → <code>(num_windows_batch, window_size, window_size, channel)</code>.</li> <li>It uses <code>reshape</code> and <code>transpose</code> to cut these windows out of the image instead of iterating over them.</li> </ul> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">Block</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""GCViT block.</span> <span class="sd"> Args:</span> <span class="sd"> window_size: window size.</span> <span class="sd"> num_heads: number of attention heads.</span> <span class="sd"> global_query: apply global window attention</span> <span class="sd"> mlp_ratio: MLP ratio.</span> <span class="sd"> qkv_bias: bool argument for query, key, 
value learnable bias.</span> <span class="sd"> qk_scale: optional float that overrides the default query-key scale.</span> <span class="sd"> dropout: dropout rate.</span> <span class="sd"> attention_dropout: attention dropout rate.</span> <span class="sd"> path_drop: drop path rate.</span> <span class="sd"> activation: activation function.</span> <span class="sd"> layer_scale: layer scaling coefficient.</span> <span class="sd"> """</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">global_query</span><span class="p">,</span> <span class="n">mlp_ratio</span><span class="o">=</span><span class="mf">4.0</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">qk_scale</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">attention_dropout</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">path_drop</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"gelu"</span><span class="p">,</span> <span class="n">layer_scale</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span> <span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span 
class="n">window_size</span> <span class="o">=</span> <span class="n">window_size</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span> <span class="o">=</span> <span class="n">num_heads</span> <span class="bp">self</span><span class="o">.</span><span class="n">global_query</span> <span class="o">=</span> <span class="n">global_query</span> <span class="bp">self</span><span class="o">.</span><span class="n">mlp_ratio</span> <span class="o">=</span> <span class="n">mlp_ratio</span> <span class="bp">self</span><span class="o">.</span><span class="n">qkv_bias</span> <span class="o">=</span> <span class="n">qkv_bias</span> <span class="bp">self</span><span class="o">.</span><span class="n">qk_scale</span> <span class="o">=</span> <span class="n">qk_scale</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">dropout</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention_dropout</span> <span class="o">=</span> <span class="n">attention_dropout</span> <span class="bp">self</span><span class="o">.</span><span class="n">path_drop</span> <span class="o">=</span> <span class="n">path_drop</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">activation</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_scale</span> <span class="o">=</span> <span class="n">layer_scale</span> <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="n">B</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span 
class="mi">0</span><span class="p">]</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mf">1e-05</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"norm1"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span> <span class="o">=</span> <span class="n">WindowAttention</span><span class="p">(</span> <span class="n">window_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">,</span> <span class="n">num_heads</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">global_query</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">global_query</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">qkv_bias</span><span class="p">,</span> <span class="n">qk_scale</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">qk_scale</span><span class="p">,</span> <span class="n">attention_dropout</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">attention_dropout</span><span class="p">,</span> <span class="n">projection_dropout</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"attn"</span><span class="p">,</span> <span class="p">)</span> <span 
class="bp">self</span><span class="o">.</span><span class="n">drop_path1</span> <span class="o">=</span> <span class="n">DropPath</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">path_drop</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop_path2</span> <span class="o">=</span> <span class="n">DropPath</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">path_drop</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mf">1e-05</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"norm2"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">mlp</span> <span class="o">=</span> <span class="n">MLP</span><span class="p">(</span> <span class="n">hidden_features</span><span class="o">=</span><span class="nb">int</span><span class="p">(</span><span class="n">C</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mlp_ratio</span><span class="p">),</span> <span class="n">dropout</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"mlp"</span><span class="p">,</span> <span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span 
class="n">layer_scale</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">gamma1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">name</span><span class="o">=</span><span class="s2">"gamma1"</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="n">C</span><span class="p">],</span> <span class="n">initializer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">Constant</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layer_scale</span><span class="p">),</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">gamma2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">name</span><span class="o">=</span><span class="s2">"gamma2"</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="n">C</span><span class="p">],</span> <span class="n">initializer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">Constant</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layer_scale</span><span 
class="p">),</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">gamma1</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="bp">self</span><span class="o">.</span><span class="n">gamma2</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_windows</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">H</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">)</span> <span class="o">*</span> <span class="nb">int</span><span class="p">(</span><span class="n">W</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">)</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">global_query</span><span class="p">:</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">q_global</span> <span class="o">=</span> <span class="n">inputs</span> <span class="k">else</span><span class="p">:</span> <span 
class="n">inputs</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="n">B</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm1</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="c1"># create windows and concat them in batch axis</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_partition</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">)</span> <span class="c1"># (B_, win_h, win_w, C)</span> <span class="c1"># flatten patch</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">,</span> <span class="n">C</span><span class="p">])</span> <span class="c1"># attention</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">global_query</span><span class="p">:</span> <span class="n">x</span> <span class="o">=</span> 
<span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">([</span><span class="n">x</span><span class="p">,</span> <span class="n">q_global</span><span class="p">])</span> <span class="k">else</span><span class="p">:</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn</span><span class="p">([</span><span class="n">x</span><span class="p">])</span> <span class="c1"># reverse window partition</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_reverse</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">C</span><span class="p">)</span> <span class="c1"># FFN</span> <span class="n">x</span> <span class="o">=</span> <span class="n">inputs</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop_path1</span><span class="p">(</span><span class="n">x</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">gamma1</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop_path2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">gamma2</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">mlp</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">norm2</span><span class="p">(</span><span class="n">x</span><span class="p">)))</span> <span class="k">return</span> <span 
class="n">x</span> <span class="k">def</span> <span class="nf">window_partition</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">window_size</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""</span> <span class="sd"> Args:</span> <span class="sd"> x: (B, H, W, C)</span> <span class="sd"> window_size: window size</span> <span class="sd"> Returns:</span> <span class="sd"> local window features (num_windows*B, window_size, window_size, C)</span> <span class="sd"> """</span> <span class="n">B</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span> <span class="n">x</span><span class="p">,</span> <span class="p">[</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">H</span> <span class="o">//</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">W</span> <span class="o">//</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="p">],</span> <span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="p">[</span><span 
class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">])</span> <span class="n">windows</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">C</span><span class="p">])</span> <span class="k">return</span> <span class="n">windows</span> <span class="k">def</span> <span class="nf">window_reverse</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">windows</span><span class="p">,</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">C</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""</span> <span class="sd"> Args:</span> <span class="sd"> windows: local window features (num_windows*B, window_size, window_size, C)</span> <span class="sd"> window_size: Window size</span> <span class="sd"> H: Height of image</span> <span class="sd"> W: Width of image</span> <span class="sd"> C: Channel of image</span> <span class="sd"> Returns:</span> <span class="sd"> x: (B, H, W, C)</span> <span class="sd"> """</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span> <span class="n">windows</span><span class="p">,</span> <span class="p">[</span> <span class="o">-</span><span class="mi">1</span><span 
class="p">,</span> <span class="n">H</span> <span class="o">//</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">W</span> <span class="o">//</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">C</span><span class="p">,</span> <span class="p">],</span> <span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">])</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">C</span><span class="p">])</span> <span class="k">return</span> <span class="n">x</span> </code></pre></div> <h3 id="level">Level</h3> <blockquote> <p><strong>Note:</strong> This module has both Transformer and CNN modules.</p> </blockquote> <p>In the model, the second module that we have used is <code>level</code>. Let's try to understand this module. As we can see from the <code>call</code> method, 1. First it creates <strong>global_token</strong> with a series of <code>FeatureExtraction</code> modules. 
As we'll see later, <code>FeatureExtraction</code> is simply a <strong>CNN</strong>-based module. 2. Then it uses a series of <code>Block</code> modules to apply <strong>local or global window attention</strong>, alternating between the two: even-indexed blocks use local window attention, while odd-indexed blocks also receive the global query. 3. Finally, it uses <code>ReduceSize</code> to downsample the <strong>contextualized features</strong>, halving the height and width and doubling the number of channels.</p> <blockquote> <p>Summary: feature_map → global_token → local/global window attention → downsample</p> </blockquote> <p><img src="https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/level.png" width=400></p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">Level</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""GCViT level.</span> <span class="sd"> Args:</span> <span class="sd"> depth: number of layers in each stage.</span> <span class="sd"> num_heads: number of heads in each stage.</span> <span class="sd"> window_size: window size in each stage.</span> <span class="sd"> keepdims: dims to keep in FeatureExtraction.</span> <span class="sd"> downsample: bool argument for down-sampling.</span> <span class="sd"> mlp_ratio: MLP ratio.</span> <span class="sd"> qkv_bias: bool argument for query, key, value learnable bias.</span> <span class="sd"> qk_scale: optional override for the query-key scaling factor.</span> <span class="sd"> dropout: dropout rate.</span> <span class="sd"> attention_dropout: attention dropout rate.</span> <span class="sd"> path_drop: drop path rate.</span> <span class="sd"> layer_scale: layer scaling coefficient.</span> <span class="sd"> """</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">depth</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span
class="n">window_size</span><span class="p">,</span> <span class="n">keepdims</span><span class="p">,</span> <span class="n">downsample</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">mlp_ratio</span><span class="o">=</span><span class="mf">4.0</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">qk_scale</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">attention_dropout</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">path_drop</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">layer_scale</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span> <span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">depth</span> <span class="o">=</span> <span class="n">depth</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span> <span class="o">=</span> <span class="n">num_heads</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span> <span class="o">=</span> <span class="n">window_size</span> <span class="bp">self</span><span class="o">.</span><span class="n">keepdims</span> <span class="o">=</span> <span class="n">keepdims</span> <span class="bp">self</span><span class="o">.</span><span class="n">downsample</span> <span class="o">=</span> <span 
class="n">downsample</span> <span class="bp">self</span><span class="o">.</span><span class="n">mlp_ratio</span> <span class="o">=</span> <span class="n">mlp_ratio</span> <span class="bp">self</span><span class="o">.</span><span class="n">qkv_bias</span> <span class="o">=</span> <span class="n">qkv_bias</span> <span class="bp">self</span><span class="o">.</span><span class="n">qk_scale</span> <span class="o">=</span> <span class="n">qk_scale</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">dropout</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention_dropout</span> <span class="o">=</span> <span class="n">attention_dropout</span> <span class="bp">self</span><span class="o">.</span><span class="n">path_drop</span> <span class="o">=</span> <span class="n">path_drop</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_scale</span> <span class="o">=</span> <span class="n">layer_scale</span> <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="n">path_drop</span> <span class="o">=</span> <span class="p">(</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">path_drop</span><span class="p">]</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">depth</span> <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">path_drop</span><span class="p">,</span> <span class="nb">list</span><span class="p">)</span> <span class="k">else</span> <span class="bp">self</span><span class="o">.</span><span class="n">path_drop</span> <span class="p">)</span> <span 
class="bp">self</span><span class="o">.</span><span class="n">blocks</span> <span class="o">=</span> <span class="p">[</span> <span class="n">Block</span><span class="p">(</span> <span class="n">window_size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">window_size</span><span class="p">,</span> <span class="n">num_heads</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">global_query</span><span class="o">=</span><span class="nb">bool</span><span class="p">(</span><span class="n">i</span> <span class="o">%</span> <span class="mi">2</span><span class="p">),</span> <span class="n">mlp_ratio</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">mlp_ratio</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">qkv_bias</span><span class="p">,</span> <span class="n">qk_scale</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">qk_scale</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">,</span> <span class="n">attention_dropout</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">attention_dropout</span><span class="p">,</span> <span class="n">path_drop</span><span class="o">=</span><span class="n">path_drop</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">layer_scale</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">layer_scale</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span 
class="s2">"blocks_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">"</span><span class="p">,</span> <span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">depth</span><span class="p">)</span> <span class="p">]</span> <span class="bp">self</span><span class="o">.</span><span class="n">down</span> <span class="o">=</span> <span class="n">ReduceSize</span><span class="p">(</span><span class="n">keepdims</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"downsample"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">q_global_gen</span> <span class="o">=</span> <span class="n">GlobalQueryGenerator</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">keepdims</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"q_global_gen"</span><span class="p">)</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">inputs</span> <span class="n">q_global</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">q_global_gen</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># shape: 
(B, win_size, win_size, C)</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">blk</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">blocks</span><span class="p">):</span> <span class="k">if</span> <span class="n">i</span> <span class="o">%</span> <span class="mi">2</span><span class="p">:</span> <span class="n">x</span> <span class="o">=</span> <span class="n">blk</span><span class="p">([</span><span class="n">x</span><span class="p">,</span> <span class="n">q_global</span><span class="p">])</span> <span class="c1"># shape: (B, H, W, C)</span> <span class="k">else</span><span class="p">:</span> <span class="n">x</span> <span class="o">=</span> <span class="n">blk</span><span class="p">([</span><span class="n">x</span><span class="p">])</span> <span class="c1"># shape: (B, H, W, C)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">downsample</span><span class="p">:</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">down</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># shape: (B, H//2, W//2, 2*C)</span> <span class="k">return</span> <span class="n">x</span> </code></pre></div> <h3 id="model">Model</h3> <p>Let's jump directly to the model. As we can see from the <code>call</code> method, 1. It creates patch embeddings from an image. This layer doesn't flatten these embeddings, which means the output of this module will be <code>(batch, height/window_size, width/window_size, embed_dim)</code> instead of <code>(batch, (height x width)/window_size^2, embed_dim)</code>. 2. Then it applies a <code>Dropout</code> layer, which randomly sets input units to 0 during training. 3.
It passes these embeddings through a series of <code>Level</code> modules, which we call <code>level</code>. In each level, a global token is generated, both local and global attention are applied, and finally downsampling is applied. 4. So, for a model with <code>n</code> <strong>levels</strong>, the output of the final level has shape <code>(batch, height/(window_size x 2^{n-1}), width/(window_size x 2^{n-1}), embed_dim x 2^{n-1})</code>, because each of the first <code>n-1</code> levels halves the spatial resolution and doubles the channels. The paper doesn't apply <strong>downsampling</strong> in the last level, so that level neither reduces the resolution nor increases the <strong>channels</strong>. 5. The output of the last level is normalized using a <code>LayerNormalization</code> module. 6. In the head, 2D features are converted to 1D features with a <code>Pooling</code> module. The output shape after this module is <code>(batch, embed_dim x 2^{n-1})</code>. 7. Finally, the pooled features are sent to a <code>Dense/Linear</code> module for classification.</p> <blockquote> <p>Summary: image → (patches + embedding) → dropout → (attention + feature extraction) → normalization → pooling → classify</p> </blockquote> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">GCViT</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""GCViT model.</span> <span class="sd"> Args:</span> <span class="sd"> window_size: window size in each stage.</span> <span class="sd"> embed_dim: feature size dimension.</span> <span class="sd"> depths: number of layers in each stage.</span> <span class="sd"> num_heads: number of heads in each stage.</span> <span class="sd"> drop_rate: dropout rate.</span> <span class="sd"> mlp_ratio: MLP ratio.</span> <span class="sd"> qkv_bias: bool argument for query, key, value learnable bias.</span> <span class="sd"> qk_scale: optional override for the query-key scaling factor.</span> <span class="sd"> attention_dropout: attention dropout rate.</span> <span class="sd"> path_drop: drop path rate.</span> <span class="sd"> 
layer_scale: layer scaling coefficient.</span> <span class="sd"> num_classes: number of classes.</span> <span class="sd"> head_activation: activation function for head.</span> <span class="sd"> """</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">window_size</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">depths</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">drop_rate</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">mlp_ratio</span><span class="o">=</span><span class="mf">3.0</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">qk_scale</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">attention_dropout</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">path_drop</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">layer_scale</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="mi">1000</span><span class="p">,</span> <span class="n">head_activation</span><span class="o">=</span><span class="s2">"softmax"</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span> <span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">window_size</span> <span class="o">=</span> 
<span class="n">window_size</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span> <span class="o">=</span> <span class="n">embed_dim</span> <span class="bp">self</span><span class="o">.</span><span class="n">depths</span> <span class="o">=</span> <span class="n">depths</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span> <span class="o">=</span> <span class="n">num_heads</span> <span class="bp">self</span><span class="o">.</span><span class="n">drop_rate</span> <span class="o">=</span> <span class="n">drop_rate</span> <span class="bp">self</span><span class="o">.</span><span class="n">mlp_ratio</span> <span class="o">=</span> <span class="n">mlp_ratio</span> <span class="bp">self</span><span class="o">.</span><span class="n">qkv_bias</span> <span class="o">=</span> <span class="n">qkv_bias</span> <span class="bp">self</span><span class="o">.</span><span class="n">qk_scale</span> <span class="o">=</span> <span class="n">qk_scale</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention_dropout</span> <span class="o">=</span> <span class="n">attention_dropout</span> <span class="bp">self</span><span class="o">.</span><span class="n">path_drop</span> <span class="o">=</span> <span class="n">path_drop</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_scale</span> <span class="o">=</span> <span class="n">layer_scale</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_classes</span> <span class="o">=</span> <span class="n">num_classes</span> <span class="bp">self</span><span class="o">.</span><span class="n">head_activation</span> <span class="o">=</span> <span class="n">head_activation</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_embed</span> <span class="o">=</span> <span class="n">PatchEmbed</span><span class="p">(</span><span class="n">embed_dim</span><span 
class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"patch_embed"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">pos_drop</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">drop_rate</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"pos_drop"</span><span class="p">)</span> <span class="n">path_drops</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">path_drop</span><span class="p">,</span> <span class="nb">sum</span><span class="p">(</span><span class="n">depths</span><span class="p">))</span> <span class="n">keepdims</span> <span class="o">=</span> <span class="p">[(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="p">(</span><span class="mi">1</span><span class="p">,)]</span> <span class="bp">self</span><span class="o">.</span><span class="n">levels</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">depths</span><span class="p">)):</span> <span class="n">path_drop</span> <span class="o">=</span> <span class="n">path_drops</span><span class="p">[</span><span class="nb">sum</span><span 
class="p">(</span><span class="n">depths</span><span class="p">[:</span><span class="n">i</span><span class="p">])</span> <span class="p">:</span> <span class="nb">sum</span><span class="p">(</span><span class="n">depths</span><span class="p">[:</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">])]</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span> <span class="n">level</span> <span class="o">=</span> <span class="n">Level</span><span class="p">(</span> <span class="n">depth</span><span class="o">=</span><span class="n">depths</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">num_heads</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">window_size</span><span class="o">=</span><span class="n">window_size</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">keepdims</span><span class="o">=</span><span class="n">keepdims</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">downsample</span><span class="o">=</span><span class="p">(</span><span class="n">i</span> <span class="o"><</span> <span class="nb">len</span><span class="p">(</span><span class="n">depths</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">),</span> <span class="n">mlp_ratio</span><span class="o">=</span><span class="n">mlp_ratio</span><span class="p">,</span> <span class="n">qkv_bias</span><span class="o">=</span><span class="n">qkv_bias</span><span class="p">,</span> <span class="n">qk_scale</span><span class="o">=</span><span class="n">qk_scale</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="n">drop_rate</span><span class="p">,</span> <span class="n">attention_dropout</span><span 
class="o">=</span><span class="n">attention_dropout</span><span class="p">,</span> <span class="n">path_drop</span><span class="o">=</span><span class="n">path_drop</span><span class="p">,</span> <span class="n">layer_scale</span><span class="o">=</span><span class="n">layer_scale</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">"levels_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">"</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">levels</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">level</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span><span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-05</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"norm"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">pool</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">GlobalAvgPool2D</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"pool"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">head</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span 
class="s2">"head"</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">head_activation</span><span class="p">)</span> <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">built</span> <span class="o">=</span> <span class="kc">True</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_embed</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="c1"># shape: (B, H, W, C)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pos_drop</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">level</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">levels</span><span class="p">:</span> <span class="n">x</span> <span class="o">=</span> <span class="n">level</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># shape: (B, H_, W_, C_)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span 
class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">pool</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># shape: (B, C__)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">head</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> <span class="k">def</span> <span class="nf">build_graph</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="o">=</span><span class="p">(</span><span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="mi">3</span><span class="p">)):</span> <span class="w"> </span><span class="sd">"""</span> <span class="sd"> ref: https://www.kaggle.com/code/ipythonx/tf-hybrid-efficientnet-swin-transformer-gradcam</span> <span class="sd"> """</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">input_shape</span><span class="p">)</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="p">[</span><span class="n">x</span><span class="p">],</span> <span class="n">outputs</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">call</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">name</span><span class="p">)</span> <span 
class="k">def</span> <span class="nf">summary</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="o">=</span><span class="p">(</span><span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="mi">3</span><span class="p">)):</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">build_graph</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> </code></pre></div> <hr /> <h2 id="build-model">Build Model</h2> <ul> <li>Let's build a complete model from all the modules explained above. We'll build the <strong>GCViT-XXTiny</strong> model using the configuration given in the paper.</li> <li>We'll also load the ported official <strong>pre-trained</strong> weights and run a few predictions as a sanity check.</li> </ul> <div class="codehilite"><pre><span></span><code><span class="c1"># Model Configs</span> <span class="n">config</span> <span class="o">=</span> <span class="p">{</span> <span class="s2">"window_size"</span><span class="p">:</span> <span class="p">(</span><span class="mi">7</span><span class="p">,</span> <span class="mi">7</span><span class="p">,</span> <span class="mi">14</span><span class="p">,</span> <span class="mi">7</span><span class="p">),</span> <span class="s2">"embed_dim"</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s2">"depths"</span><span class="p">:</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="s2">"num_heads"</span><span class="p">:</span> <span class="p">(</span><span class="mi">2</span><span class="p">,</span> 
<span class="mi">4</span><span class="p">,</span> <span class="mi">8</span><span class="p">,</span> <span class="mi">16</span><span class="p">),</span> <span class="s2">"mlp_ratio"</span><span class="p">:</span> <span class="mf">3.0</span><span class="p">,</span> <span class="s2">"path_drop"</span><span class="p">:</span> <span class="mf">0.2</span><span class="p">,</span> <span class="p">}</span> <span class="n">ckpt_link</span> <span class="o">=</span> <span class="p">(</span> <span class="s2">"https://github.com/awsaf49/gcvit-tf/releases/download/v1.1.6/gcvitxxtiny.keras"</span> <span class="p">)</span> <span class="c1"># Build Model</span> <span class="n">model</span> <span class="o">=</span> <span class="n">GCViT</span><span class="p">(</span><span class="o">**</span><span class="n">config</span><span class="p">)</span> <span class="n">inp</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))</span> <span class="n">out</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">inp</span><span class="p">)</span> <span class="c1"># Load Weights</span> <span class="n">ckpt_path</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">get_file</span><span class="p">(</span><span class="n">ckpt_link</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"/"</span><span 
class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">ckpt_link</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">load_weights</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">)</span> <span class="c1"># Summary</span> <span class="n">model</span><span class="o">.</span><span class="n">summary</span><span class="p">((</span><span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Downloading data from https://github.com/awsaf49/gcvit-tf/releases/download/v1.1.6/gcvitxxtiny.keras 48767519/48767519 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step </code></pre></div> </div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "gc_vi_t"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩ │ input_layer (<span style="color: #0087ff; text-decoration-color: #0087ff">InputLayer</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">224</span>, <span style="color: #00af00; text-decoration-color: #00af00">224</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: #00af00; text-decoration-color: 
#00af00">0</span> │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ patch_embed (<span style="color: #0087ff; text-decoration-color: #0087ff">PatchEmbed</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">56</span>, <span style="color: #00af00; text-decoration-color: #00af00">56</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">45,632</span> │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ pos_drop (<span style="color: #0087ff; text-decoration-color: #0087ff">Dropout</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">56</span>, <span style="color: #00af00; text-decoration-color: #00af00">56</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ levels_0 (<span style="color: #0087ff; text-decoration-color: #0087ff">Level</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">28</span>, <span style="color: #00af00; text-decoration-color: #00af00">28</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">180,964</span> │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ levels_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">Level</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">14</span>, <span style="color: 
#00af00; text-decoration-color: #00af00">14</span>, <span style="color: #00af00; text-decoration-color: #00af00">256</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">688,456</span> │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ levels_2 (<span style="color: #0087ff; text-decoration-color: #0087ff">Level</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">7</span>, <span style="color: #00af00; text-decoration-color: #00af00">7</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">5,170,608</span> │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ levels_3 (<span style="color: #0087ff; text-decoration-color: #0087ff">Level</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">7</span>, <span style="color: #00af00; text-decoration-color: #00af00">7</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">5,395,744</span> │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ norm (<span style="color: #0087ff; text-decoration-color: #0087ff">LayerNormalization</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">7</span>, <span style="color: #00af00; text-decoration-color: #00af00">7</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">1,024</span> │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ pool (<span style="color: #0087ff; text-decoration-color: 
#0087ff">GlobalAveragePooling2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├────────────────────────────────────┼───────────────────────────────┼─────────────┤ │ head (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">1000</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">513,000</span> │ └────────────────────────────────────┴───────────────────────────────┴─────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">11,995,428</span> (45.76 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">11,995,428</span> (45.76 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <hr /> <h2 id="sanity-check-for-pretrained-weights">Sanity check for Pre-Trained Weights</h2> <div class="codehilite"><pre><span></span><code><span class="n">img</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">applications</span><span class="o">.</span><span class="n">imagenet_utils</span><span class="o">.</span><span 
class="n">preprocess_input</span><span class="p">(</span> <span class="n">chelsea</span><span class="p">(),</span> <span class="n">mode</span><span class="o">=</span><span class="s2">"torch"</span> <span class="p">)</span> <span class="c1"># Chelsea the cat</span> <span class="n">img</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="p">(</span><span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">))[</span><span class="kc">None</span><span class="p">,]</span> <span class="c1"># resize & create batch</span> <span class="n">pred</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">img</span><span class="p">)</span> <span class="n">pred_dec</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">applications</span><span class="o">.</span><span class="n">imagenet_utils</span><span class="o">.</span><span class="n">decode_predictions</span><span class="p">(</span><span class="n">pred</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"</span><span class="se">\n</span><span class="s2"># Image:"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">6</span><span class="p">,</span> <span class="mi">6</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">chelsea</span><span class="p">())</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span 
class="p">()</span> <span class="nb">print</span><span class="p">()</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"# Prediction (Top 5):"</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">5</span><span class="p">):</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"</span><span class="si">{:<12}</span><span class="s2"> : </span><span class="si">{:0.2f}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">pred_dec</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">],</span> <span class="n">pred_dec</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">2</span><span class="p">]))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json 35363/35363 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code># Image: </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/image_classification_using_global_context_vision_transformer/image_classification_using_global_context_vision_transformer_24_1.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code># Prediction (Top 5): Egyptian_cat : 0.72 tiger_cat : 0.04 tabby : 0.03 crossword_puzzle : 0.01 panpipe : 0.00 </code></pre></div> </div> <h1 id="gcvit">Fine-tune <strong>GCViT</strong> Model</h1> <p>In the following cells, we will fine-tune the <strong>GCViT</strong> model on the TF Flowers dataset (<code>tf_flowers</code>), which consists of <code>5</code> classes.</p> <h3 id="configs">Configs</h3> <div
class="codehilite"><pre><span></span><code><span class="c1"># Model</span> <span class="n">IMAGE_SIZE</span> <span class="o">=</span> <span class="p">(</span><span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">)</span> <span class="c1"># Hyper Params</span> <span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">32</span> <span class="n">EPOCHS</span> <span class="o">=</span> <span class="mi">5</span> <span class="c1"># Dataset</span> <span class="n">CLASSES</span> <span class="o">=</span> <span class="p">[</span> <span class="s2">"dandelion"</span><span class="p">,</span> <span class="s2">"daisy"</span><span class="p">,</span> <span class="s2">"tulips"</span><span class="p">,</span> <span class="s2">"sunflowers"</span><span class="p">,</span> <span class="s2">"roses"</span><span class="p">,</span> <span class="p">]</span> <span class="c1"># order matches the tf_flowers label indices; don't change it</span> <span class="c1"># Other constants</span> <span class="n">MEAN</span> <span class="o">=</span> <span class="mi">255</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.485</span><span class="p">,</span> <span class="mf">0.456</span><span class="p">,</span> <span class="mf">0.406</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span><span class="p">)</span> <span class="c1"># imagenet mean</span> <span class="n">STD</span> <span class="o">=</span> <span class="mi">255</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.229</span><span class="p">,</span> <span class="mf">0.224</span><span class="p">,</span> <span class="mf">0.225</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span><span class="p">)</span> <span
class="c1"># imagenet std</span> <span class="n">AUTO</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span> </code></pre></div> <hr /> <h2 id="data-loader">Data Loader</h2> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">make_dataset</span><span class="p">(</span><span class="n">dataset</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="p">,</span> <span class="n">train</span><span class="p">:</span> <span class="nb">bool</span><span class="p">,</span> <span class="n">image_size</span><span class="p">:</span> <span class="nb">tuple</span> <span class="o">=</span> <span class="n">IMAGE_SIZE</span><span class="p">):</span> <span class="k">def</span> <span class="nf">preprocess</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">label</span><span class="p">):</span> <span class="c1"># training only: random horizontal flip as augmentation</span> <span class="k">if</span> <span class="n">train</span><span class="p">:</span> <span class="k">if</span> <span class="n">tf</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">[])</span> <span class="o">></span> <span class="mf">0.5</span><span class="p">:</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">flip_left_right</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span
class="n">resize</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">image_size</span><span class="p">,</span> <span class="n">method</span><span class="o">=</span><span class="s2">"bicubic"</span><span class="p">)</span> <span class="n">image</span> <span class="o">=</span> <span class="p">(</span><span class="n">image</span> <span class="o">-</span> <span class="n">MEAN</span><span class="p">)</span> <span class="o">/</span> <span class="n">STD</span> <span class="c1"># normalization</span> <span class="k">return</span> <span class="n">image</span><span class="p">,</span> <span class="n">label</span> <span class="k">if</span> <span class="n">train</span><span class="p">:</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">BATCH_SIZE</span> <span class="o">*</span> <span class="mi">10</span><span class="p">)</span> <span class="k">return</span> <span class="n">dataset</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">preprocess</span><span class="p">,</span> <span class="n">AUTO</span><span class="p">)</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span><span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">AUTO</span><span class="p">)</span> </code></pre></div> <h3 id="flower-dataset">Flower Dataset</h3> <div class="codehilite"><pre><span></span><code><span class="n">train_dataset</span><span class="p">,</span> <span class="n">val_dataset</span> <span class="o">=</span> <span class="n">tfds</span><span class="o">.</span><span class="n">load</span><span class="p">(</span> <span class="s2">"tf_flowers"</span><span class="p">,</span> <span class="n">split</span><span 
class="o">=</span><span class="p">[</span><span class="s2">"train[:90%]"</span><span class="p">,</span> <span class="s2">"train[90%:]"</span><span class="p">],</span> <span class="n">as_supervised</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">try_gcs</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="c1"># set True to read from the public GCS bucket (needed on TPU)</span> <span class="p">)</span> <span class="n">train_dataset</span> <span class="o">=</span> <span class="n">make_dataset</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span> <span class="n">val_dataset</span> <span class="o">=</span> <span class="n">make_dataset</span><span class="p">(</span><span class="n">val_dataset</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Downloading and preparing dataset 218.21 MiB (download: 218.21 MiB, generated: 221.83 MiB, total: 440.05 MiB) to /root/tensorflow_datasets/tf_flowers/3.0.1... Dataset tf_flowers downloaded and prepared to /root/tensorflow_datasets/tf_flowers/3.0.1. Subsequent calls will reuse this data. 
</code></pre></div> </div> <h3 id="rebuild-model-for-flower-dataset">Re-Build Model for Flower Dataset</h3> <div class="codehilite"><pre><span></span><code><span class="c1"># Re-Build Model</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">GCViT</span><span class="p">(</span><span class="o">**</span><span class="n">config</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="mi">104</span><span class="p">)</span>
<span class="n">inp</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="mi">3</span><span class="p">)))</span>
<span class="n">out</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">inp</span><span class="p">)</span>

<span class="c1"># Load Weights</span>
<span class="n">ckpt_path</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">get_file</span><span class="p">(</span><span class="n">ckpt_link</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"/"</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">],</span> <span class="n">ckpt_link</span><span class="p">)</span>
<span class="n">model</span><span class="o">.</span><span class="n">load_weights</span><span class="p">(</span><span class="n">ckpt_path</span><span class="p">,</span> <span class="n">skip_mismatch</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>

<span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span>
    <span class="n">loss</span><span class="o">=</span><span class="s2">"sparse_categorical_crossentropy"</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="s2">"adam"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"accuracy"</span><span class="p">]</span>
<span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>/usr/local/lib/python3.10/dist-packages/keras/src/saving/saving_lib.py:269: UserWarning: A total of 1 objects could not be loaded. Example error message for object &lt;Dense name=head, built=True&gt;: </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Layer 'head' expected 2 variables, but received 0 variables during loading.
Expected: ['kernel', 'bias'] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>List of objects that could not be loaded: [&lt;Dense name=head, built=True&gt;]
  warnings.warn(msg) </code></pre></div> </div> <h3 id="training">Training</h3> <div class="codehilite"><pre><span></span><code><span class="n">history</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span>
    <span class="n">train_dataset</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">val_dataset</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">EPOCHS</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">1</span>
<span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/5
104/104 ━━━━━━━━━━━━━━━━━━━━ 153s 581ms/step - accuracy: 0.5140 - loss: 1.4615 - val_accuracy: 0.8828 - val_loss: 0.3485
Epoch 2/5
104/104 ━━━━━━━━━━━━━━━━━━━━ 7s 69ms/step - accuracy: 0.8775 - loss: 0.3437 - val_accuracy: 0.8828 - val_loss: 0.3508
Epoch 3/5
104/104 ━━━━━━━━━━━━━━━━━━━━ 7s 68ms/step - accuracy: 0.8937 - loss: 0.2918 - val_accuracy: 0.9019 - val_loss: 0.2953
Epoch 4/5
104/104 ━━━━━━━━━━━━━━━━━━━━ 7s 68ms/step - accuracy: 0.9232 - loss: 0.2397 - val_accuracy: 0.9183 - val_loss: 0.2212
Epoch 5/5
104/104 ━━━━━━━━━━━━━━━━━━━━ 7s 68ms/step - accuracy: 0.9456 - loss: 0.1645 - val_accuracy: 0.9210 - val_loss: 0.2897 </code></pre></div> </div> <hr /> <h2 id="reference">Reference</h2> <ul> <li><a href="https://github.com/awsaf49/gcvit-tf">gcvit-tf - A Python library for GCViT with TF2.0</a></li> <li><a href="https://github.com/NVlabs/GCVit">gcvit - Official codebase for GCViT</a></li> </ul> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a 
href='#image-classification-using-global-context-vision-transformer'>Image Classification using Global Context Vision Transformer</a> </div> <div class='k-outline-depth-1'> <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#motivation'>Motivation</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#architecture'>Architecture</a> </div> <div class='k-outline-depth-3'> <a href='#unit-blocks'>Unit Blocks</a> </div> <div class='k-outline-depth-3'> <a href='#stem'>Stem</a> </div> <div class='k-outline-depth-3'> <a href='#global-token-gen'>Global Token Gen.</a> </div> <div class='k-outline-depth-3'> <a href='#attention'>Attention</a> </div> <div class='k-outline-depth-3'> <a href='#block'>Block</a> </div> <div class='k-outline-depth-3'> <a href='#window'>Window</a> </div> <div class='k-outline-depth-3'> <a href='#level'>Level</a> </div> <div class='k-outline-depth-3'> <a href='#model'>Model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#build-model'>Build Model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#sanity-check-for-pretrained-weights'>Sanity check for Pre-Trained Weights</a> </div> <div class='k-outline-depth-1'> <a href='#finetune-gcvit-model'>Fine-tune GCViT Model</a> </div> <div class='k-outline-depth-3'> <a href='#configs'>Configs</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#data-loader'>Data Loader</a> </div> <div class='k-outline-depth-3'> <a href='#flower-dataset'>Flower Dataset</a> </div> <div class='k-outline-depth-3'> <a href='#rebuild-model-for-flower-dataset'>Re-Build Model for Flower Dataset</a> </div> <div class='k-outline-depth-3'> <a href='#training'>Training</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#reference'>Reference</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a 
href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>