CINXE.COM
Train a Vision Transformer on small datasets
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/vit_small_ds/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Train a Vision Transformer on small datasets"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Train a Vision Transformer on small datasets"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Train a Vision Transformer on small datasets</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a class="nav-sublink2" href="/examples/vision/mobilevit/">A mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2 active" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using Composable Fully-Convolutional Networks</a> <a class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved Robustness</a> <a class="nav-sublink2" href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/bit/">Image Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a class="nav-sublink2" href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/vision/'>Computer Vision</a> / Train a Vision Transformer on small datasets </div> <div class='k-content'> <h1 id="train-a-vision-transformer-on-small-datasets">Train a Vision Transformer on small datasets</h1> <p><strong>Author:</strong> <a href="https://twitter.com/ariG23498">Aritra Roy Gosthipaty</a><br> <strong>Date created:</strong> 2022/01/07<br> <strong>Last modified:</strong> 2022/01/10<br> <strong>Description:</strong> Training a ViT from scratch on smaller datasets with shifted patch tokenization and locality self-attention.</p> <div class='example_version_banner keras_2'>ⓘ This example uses Keras 2</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/vit_small_ds.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/vit_small_ds.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>In the academic paper <a href="https://arxiv.org/abs/2010.11929">An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale</a>, the authors mention that Vision Transformers (ViT) are data-hungry. Therefore, pretraining a ViT on a large-sized dataset like JFT300M and fine-tuning it on medium-sized datasets (like ImageNet) is the only way to beat state-of-the-art Convolutional Neural Network models.</p> <p>The self-attention layer of ViT lacks <strong>locality inductive bias</strong> (the notion that image pixels are locally correlated and that their correlation maps are translation-invariant). This is the reason why ViTs need more data. On the other hand, CNNs look at images through spatial sliding windows, which helps them get better results with smaller datasets.</p> <p>In the academic paper <a href="https://arxiv.org/abs/2112.13492v1">Vision Transformer for Small-Size Datasets</a>, the authors set out to tackle the problem of locality inductive bias in ViTs.</p> <p>The main ideas are:</p> <ul> <li><strong>Shifted Patch Tokenization</strong></li> <li><strong>Locality Self Attention</strong></li> </ul> <p>This example implements the ideas of the paper. A large part of this example is inspired from <a href="https://keras.io/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a>.</p> <p><em>Note</em>: This example requires TensorFlow 2.6 or higher, as well as <a href="https://www.tensorflow.org/addons">TensorFlow Addons</a>, which can be installed using the following command:</p> <div class="codehilite"><pre><span></span><code><span class="n">pip</span> <span class="n">install</span> <span class="o">-</span><span class="n">qq</span> <span class="o">-</span><span class="n">U</span> <span class="n">tensorflow</span><span class="o">-</span><span class="n">addons</span> </code></pre></div> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">math</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span> <span class="kn">from</span> <span class="nn">tensorflow</span> <span class="kn">import</span> <span class="n">keras</span> <span class="kn">import</span> <span class="nn">tensorflow_addons</span> <span class="k">as</span> <span class="nn">tfa</span> <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span> <span class="kn">from</span> <span class="nn">tensorflow.keras</span> <span class="kn">import</span> <span class="n">layers</span> <span class="c1"># Setting seed for reproducibiltiy</span> <span class="n">SEED</span> <span class="o">=</span> <span class="mi">42</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">set_random_seed</span><span class="p">(</span><span class="n">SEED</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="prepare-the-data">Prepare the data</h2> <div class="codehilite"><pre><span></span><code><span class="n">NUM_CLASSES</span> <span class="o">=</span> <span class="mi">100</span> <span class="n">INPUT_SHAPE</span> <span class="o">=</span> <span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">cifar100</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"x_train shape: </span><span class="si">{</span><span class="n">x_train</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2"> - y_train shape: </span><span class="si">{</span><span class="n">y_train</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"x_test shape: </span><span class="si">{</span><span class="n">x_test</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2"> - y_test shape: </span><span class="si">{</span><span class="n">y_test</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Downloading data from https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz 169009152/169001437 [==============================] - 16s 0us/step 169017344/169001437 [==============================] - 16s 0us/step x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1) x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1) </code></pre></div> </div> <hr /> <h2 id="configure-the-hyperparameters">Configure the hyperparameters</h2> <p>The hyperparameters are different from the paper. Feel free to tune the hyperparameters yourself.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># DATA</span> <span class="n">BUFFER_SIZE</span> <span class="o">=</span> <span class="mi">512</span> <span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">256</span> <span class="c1"># AUGMENTATION</span> <span class="n">IMAGE_SIZE</span> <span class="o">=</span> <span class="mi">72</span> <span class="n">PATCH_SIZE</span> <span class="o">=</span> <span class="mi">6</span> <span class="n">NUM_PATCHES</span> <span class="o">=</span> <span class="p">(</span><span class="n">IMAGE_SIZE</span> <span class="o">//</span> <span class="n">PATCH_SIZE</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="c1"># OPTIMIZER</span> <span class="n">LEARNING_RATE</span> <span class="o">=</span> <span class="mf">0.001</span> <span class="n">WEIGHT_DECAY</span> <span class="o">=</span> <span class="mf">0.0001</span> <span class="c1"># TRAINING</span> <span class="n">EPOCHS</span> <span class="o">=</span> <span class="mi">50</span> <span class="c1"># ARCHITECTURE</span> <span class="n">LAYER_NORM_EPS</span> <span class="o">=</span> <span class="mf">1e-6</span> <span class="n">TRANSFORMER_LAYERS</span> <span class="o">=</span> <span class="mi">8</span> <span class="n">PROJECTION_DIM</span> <span class="o">=</span> <span class="mi">64</span> <span class="n">NUM_HEADS</span> <span class="o">=</span> <span class="mi">4</span> <span class="n">TRANSFORMER_UNITS</span> <span class="o">=</span> <span class="p">[</span> <span class="n">PROJECTION_DIM</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="n">PROJECTION_DIM</span><span class="p">,</span> <span class="p">]</span> <span class="n">MLP_HEAD_UNITS</span> <span class="o">=</span> <span class="p">[</span><span class="mi">2048</span><span class="p">,</span> <span class="mi">1024</span><span class="p">]</span> </code></pre></div> <hr /> <h2 id="use-data-augmentation">Use data augmentation</h2> <p>A snippet from the paper:</p> <p><em>"According to DeiT, various techniques are required to effectively train ViTs. Thus, we applied data augmentations such as CutMix, Mixup, Auto Augment, Repeated Augment to all models."</em></p> <p>In this example, we will focus solely on the novelty of the approach and not on reproducing the paper results. For this reason, we don't use the mentioned data augmentation schemes. Please feel free to add to or remove from the augmentation pipeline.</p> <div class="codehilite"><pre><span></span><code><span class="n">data_augmentation</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">layers</span><span class="o">.</span><span class="n">Normalization</span><span class="p">(),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Resizing</span><span class="p">(</span><span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="n">IMAGE_SIZE</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">RandomFlip</span><span class="p">(</span><span class="s2">"horizontal"</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">RandomRotation</span><span class="p">(</span><span class="n">factor</span><span class="o">=</span><span class="mf">0.02</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">RandomZoom</span><span class="p">(</span><span class="n">height_factor</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">width_factor</span><span class="o">=</span><span class="mf">0.2</span><span class="p">),</span> <span class="p">],</span> <span class="n">name</span><span class="o">=</span><span class="s2">"data_augmentation"</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Compute the mean and the variance of the training data for normalization.</span> <span class="n">data_augmentation</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">adapt</span><span class="p">(</span><span class="n">x_train</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="implement-shifted-patch-tokenization">Implement Shifted Patch Tokenization</h2> <p>In a ViT pipeline, the input images are divided into patches that are then linearly projected into tokens. Shifted patch tokenization (STP) is introduced to combat the low receptive field of ViTs. The steps for Shifted Patch Tokenization are as follows:</p> <ul> <li>Start with an image.</li> <li>Shift the image in diagonal directions.</li> <li>Concat the diagonally shifted images with the original image.</li> <li>Extract patches of the concatenated images.</li> <li>Flatten the spatial dimension of all patches.</li> <li>Layer normalize the flattened patches and then project it.</li> </ul> <table> <thead> <tr> <th style="text-align: center;"><img alt="Shifted Patch Toekenization" src="https://i.imgur.com/bUnHxd0.png" /></th> </tr> </thead> <tbody> <tr> <td style="text-align: center;">Shifted Patch Tokenization <a href="https://arxiv.org/abs/2112.13492v1">Source</a></td> </tr> </tbody> </table> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">ShiftedPatchTokenization</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">image_size</span><span class="o">=</span><span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="n">patch_size</span><span class="o">=</span><span class="n">PATCH_SIZE</span><span class="p">,</span> <span class="n">num_patches</span><span class="o">=</span><span class="n">NUM_PATCHES</span><span class="p">,</span> <span class="n">projection_dim</span><span class="o">=</span><span class="n">PROJECTION_DIM</span><span class="p">,</span> <span class="n">vanilla</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span> <span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">vanilla</span> <span class="o">=</span> <span class="n">vanilla</span> <span class="c1"># Flag to swtich to vanilla patch extractor</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_size</span> <span class="o">=</span> <span class="n">image_size</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span> <span class="o">=</span> <span class="n">patch_size</span> <span class="bp">self</span><span class="o">.</span><span class="n">half_patch</span> <span class="o">=</span> <span class="n">patch_size</span> <span class="o">//</span> <span class="mi">2</span> <span class="bp">self</span><span class="o">.</span><span class="n">flatten_patches</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Reshape</span><span class="p">((</span><span class="n">num_patches</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="n">projection_dim</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_norm</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span><span class="n">epsilon</span><span class="o">=</span><span class="n">LAYER_NORM_EPS</span><span class="p">)</span> <span class="k">def</span> <span class="nf">crop_shift_pad</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">images</span><span class="p">,</span> <span class="n">mode</span><span class="p">):</span> <span class="c1"># Build the diagonally shifted images</span> <span class="k">if</span> <span class="n">mode</span> <span class="o">==</span> <span class="s2">"left-up"</span><span class="p">:</span> <span class="n">crop_height</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">half_patch</span> <span class="n">crop_width</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">half_patch</span> <span class="n">shift_height</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">shift_width</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">elif</span> <span class="n">mode</span> <span class="o">==</span> <span class="s2">"left-down"</span><span class="p">:</span> <span class="n">crop_height</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">crop_width</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">half_patch</span> <span class="n">shift_height</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">half_patch</span> <span class="n">shift_width</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">elif</span> <span class="n">mode</span> <span class="o">==</span> <span class="s2">"right-up"</span><span class="p">:</span> <span class="n">crop_height</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">half_patch</span> <span class="n">crop_width</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">shift_height</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">shift_width</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">half_patch</span> <span class="k">else</span><span class="p">:</span> <span class="n">crop_height</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">crop_width</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">shift_height</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">half_patch</span> <span class="n">shift_width</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">half_patch</span> <span class="c1"># Crop the shifted images and pad them</span> <span class="n">crop</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">crop_to_bounding_box</span><span class="p">(</span> <span class="n">images</span><span class="p">,</span> <span class="n">offset_height</span><span class="o">=</span><span class="n">crop_height</span><span class="p">,</span> <span class="n">offset_width</span><span class="o">=</span><span class="n">crop_width</span><span class="p">,</span> <span class="n">target_height</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">image_size</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">half_patch</span><span class="p">,</span> <span class="n">target_width</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">image_size</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">half_patch</span><span class="p">,</span> <span class="p">)</span> <span class="n">shift_pad</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">pad_to_bounding_box</span><span class="p">(</span> <span class="n">crop</span><span class="p">,</span> <span class="n">offset_height</span><span class="o">=</span><span class="n">shift_height</span><span class="p">,</span> <span class="n">offset_width</span><span class="o">=</span><span class="n">shift_width</span><span class="p">,</span> <span class="n">target_height</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">image_size</span><span class="p">,</span> <span class="n">target_width</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">image_size</span><span class="p">,</span> <span class="p">)</span> <span class="k">return</span> <span class="n">shift_pad</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">images</span><span class="p">):</span> <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">vanilla</span><span class="p">:</span> <span class="c1"># Concat the shifted images with the original image</span> <span class="n">images</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">concat</span><span class="p">(</span> <span class="p">[</span> <span class="n">images</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_shift_pad</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">"left-up"</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_shift_pad</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">"left-down"</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_shift_pad</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">"right-up"</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">crop_shift_pad</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">mode</span><span class="o">=</span><span class="s2">"right-down"</span><span class="p">),</span> <span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Patchify the images and flatten it</span> <span class="n">patches</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">extract_patches</span><span class="p">(</span> <span class="n">images</span><span class="o">=</span><span class="n">images</span><span class="p">,</span> <span class="n">sizes</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">strides</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">rates</span><span class="o">=</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"VALID"</span><span class="p">,</span> <span class="p">)</span> <span class="n">flat_patches</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">flatten_patches</span><span class="p">(</span><span class="n">patches</span><span class="p">)</span> <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">vanilla</span><span class="p">:</span> <span class="c1"># Layer normalize the flat patches and linearly project it</span> <span class="n">tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_norm</span><span class="p">(</span><span class="n">flat_patches</span><span class="p">)</span> <span class="n">tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span><span class="p">(</span><span class="n">tokens</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="c1"># Linearly project the flat patches</span> <span class="n">tokens</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span><span class="p">(</span><span class="n">flat_patches</span><span class="p">)</span> <span class="k">return</span> <span class="p">(</span><span class="n">tokens</span><span class="p">,</span> <span class="n">patches</span><span class="p">)</span> </code></pre></div> <h3 id="visualize-the-patches">Visualize the patches</h3> <div class="codehilite"><pre><span></span><code><span class="c1"># Get a random image from the training dataset</span> <span class="c1"># and resize the image</span> <span class="n">image</span> <span class="o">=</span> <span class="n">x_train</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">x_train</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]))]</span> <span class="n">resized_image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span> <span class="n">tf</span><span class="o">.</span><span class="n">convert_to_tensor</span><span class="p">([</span><span class="n">image</span><span class="p">]),</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="n">IMAGE_SIZE</span><span class="p">)</span> <span class="p">)</span> <span class="c1"># Vanilla patch maker: This takes an image and divides into</span> <span class="c1"># patches as in the original ViT paper</span> <span class="p">(</span><span class="n">token</span><span class="p">,</span> <span class="n">patch</span><span class="p">)</span> <span class="o">=</span> <span class="n">ShiftedPatchTokenization</span><span class="p">(</span><span class="n">vanilla</span><span class="o">=</span><span class="kc">True</span><span class="p">)(</span><span class="n">resized_image</span> <span class="o">/</span> <span class="mf">255.0</span><span class="p">)</span> <span class="p">(</span><span class="n">token</span><span class="p">,</span> <span class="n">patch</span><span class="p">)</span> <span class="o">=</span> <span class="p">(</span><span class="n">token</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">patch</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="n">n</span> <span class="o">=</span> <span class="n">patch</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="n">count</span> <span class="o">=</span> <span class="mi">1</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span> <span class="k">for</span> <span class="n">row</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n</span><span class="p">):</span> <span class="k">for</span> <span class="n">col</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n</span><span class="p">):</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">count</span><span class="p">)</span> <span class="n">count</span> <span class="o">=</span> <span class="n">count</span> <span class="o">+</span> <span class="mi">1</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">patch</span><span class="p">[</span><span class="n">row</span><span class="p">][</span><span class="n">col</span><span class="p">],</span> <span class="p">(</span><span class="n">PATCH_SIZE</span><span class="p">,</span> <span class="n">PATCH_SIZE</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="c1"># Shifted Patch Tokenization: This layer takes the image, shifts it</span> <span class="c1"># diagonally and then extracts patches from the concatinated images</span> <span class="p">(</span><span class="n">token</span><span class="p">,</span> <span class="n">patch</span><span class="p">)</span> <span class="o">=</span> <span class="n">ShiftedPatchTokenization</span><span class="p">(</span><span class="n">vanilla</span><span class="o">=</span><span class="kc">False</span><span class="p">)(</span><span class="n">resized_image</span> <span class="o">/</span> <span class="mf">255.0</span><span class="p">)</span> <span class="p">(</span><span class="n">token</span><span class="p">,</span> <span class="n">patch</span><span class="p">)</span> <span class="o">=</span> <span class="p">(</span><span class="n">token</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">patch</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="n">n</span> <span class="o">=</span> <span class="n">patch</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="n">shifted_images</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"ORIGINAL"</span><span class="p">,</span> <span class="s2">"LEFT-UP"</span><span class="p">,</span> <span class="s2">"LEFT-DOWN"</span><span class="p">,</span> <span class="s2">"RIGHT-UP"</span><span class="p">,</span> <span class="s2">"RIGHT-DOWN"</span><span class="p">]</span> <span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">name</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">shifted_images</span><span class="p">):</span> <span class="nb">print</span><span class="p">(</span><span class="n">name</span><span class="p">)</span> <span class="n">count</span> <span class="o">=</span> <span class="mi">1</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span> <span class="k">for</span> <span class="n">row</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n</span><span class="p">):</span> <span class="k">for</span> <span class="n">col</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">n</span><span class="p">):</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">count</span><span class="p">)</span> <span class="n">count</span> <span class="o">=</span> <span class="n">count</span> <span class="o">+</span> <span class="mi">1</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">patch</span><span class="p">[</span><span class="n">row</span><span class="p">][</span><span class="n">col</span><span class="p">],</span> <span class="p">(</span><span class="n">PATCH_SIZE</span><span class="p">,</span> <span class="n">PATCH_SIZE</span><span class="p">,</span> <span class="mi">5</span> <span class="o">*</span> <span class="mi">3</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">index</span> <span class="p">:</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">index</span> <span class="o">+</span> <span class="mi">3</span><span class="p">])</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>2022-01-12 04:50:54.960908: I tensorflow/stream_executor/cuda/cuda_blas.cc:1774] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once. </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/vit_small_ds/vit_small_ds_13_1.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>ORIGINAL </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/vit_small_ds/vit_small_ds_13_3.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>LEFT-UP </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/vit_small_ds/vit_small_ds_13_5.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>LEFT-DOWN </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/vit_small_ds/vit_small_ds_13_7.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>RIGHT-UP </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/vit_small_ds/vit_small_ds_13_9.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>RIGHT-DOWN </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/vit_small_ds/vit_small_ds_13_11.png" /></p> <hr /> <h2 id="implement-the-patch-encoding-layer">Implement the patch encoding layer</h2> <p>This layer accepts projected patches and then adds positional information to them.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">PatchEncoder</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">num_patches</span><span class="o">=</span><span class="n">NUM_PATCHES</span><span class="p">,</span> <span class="n">projection_dim</span><span class="o">=</span><span class="n">PROJECTION_DIM</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span> <span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_patches</span> <span class="o">=</span> <span class="n">num_patches</span> <span class="bp">self</span><span class="o">.</span><span class="n">position_embedding</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span> <span class="n">input_dim</span><span class="o">=</span><span class="n">num_patches</span><span class="p">,</span> <span class="n">output_dim</span><span class="o">=</span><span class="n">projection_dim</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">positions</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">range</span><span class="p">(</span><span class="n">start</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">limit</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_patches</span><span class="p">,</span> <span class="n">delta</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">encoded_patches</span><span class="p">):</span> <span class="n">encoded_positions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">position_embedding</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">positions</span><span class="p">)</span> <span class="n">encoded_patches</span> <span class="o">=</span> <span class="n">encoded_patches</span> <span class="o">+</span> <span class="n">encoded_positions</span> <span class="k">return</span> <span class="n">encoded_patches</span> </code></pre></div> <hr /> <h2 id="implement-locality-self-attention">Implement Locality Self Attention</h2> <p>The regular attention equation is stated below.</p> <table> <thead> <tr> <th style="text-align: center;"><img alt="Equation of attention" src="https://miro.medium.com/max/396/1*P9sV1xXM10t943bXy_G9yg.png" /></th> </tr> </thead> <tbody> <tr> <td style="text-align: center;"><a href="https://towardsdatascience.com/attention-is-all-you-need-discovering-the-transformer-paper-73e5ff5e0634">Source</a></td> </tr> </tbody> </table> <p>The attention module takes a query, key, and value. First, we compute the similarity between the query and key via a dot product. Then, the result is scaled by the square root of the key dimension. The scaling prevents the softmax function from having an overly small gradient. Softmax is then applied to the scaled dot product to produce the attention weights. The value is then modulated via the attention weights.</p> <p>In self-attention, query, key and value come from the same input. The dot product would result in large self-token relations rather than inter-token relations. This also means that the softmax gives higher probabilities to self-token relations than the inter-token relations. To combat this, the authors propose masking the diagonal of the dot product. This way, we force the attention module to pay more attention to the inter-token relations.</p> <p>The scaling factor is a constant in the regular attention module. This acts like a temperature term that can modulate the softmax function. The authors suggest a learnable temperature term instead of a constant.</p> <table> <thead> <tr> <th style="text-align: center;"><img alt="Implementation of LSA" src="https://i.imgur.com/GTV99pk.png" /></th> </tr> </thead> <tbody> <tr> <td style="text-align: center;">Locality Self Attention <a href="https://arxiv.org/abs/2112.13492v1">Source</a></td> </tr> </tbody> </table> <p>The above two pointers make the Locality Self Attention. We have subclassed the <a href="https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention"><code>layers.MultiHeadAttention</code></a> and implemented the trainable temperature. The attention mask is built at a later stage.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">MultiHeadAttentionLSA</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">MultiHeadAttention</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="c1"># The trainable temperature term. The initial value is</span> <span class="c1"># the square root of the key dimension.</span> <span class="bp">self</span><span class="o">.</span><span class="n">tau</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="nb">float</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_key_dim</span><span class="p">)),</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="k">def</span> <span class="nf">_compute_attention</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">,</span> <span class="n">attention_mask</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">query</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">multiply</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="mf">1.0</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">tau</span><span class="p">)</span> <span class="n">attention_scores</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_dot_product_equation</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">query</span><span class="p">)</span> <span class="n">attention_scores</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_masked_softmax</span><span class="p">(</span><span class="n">attention_scores</span><span class="p">,</span> <span class="n">attention_mask</span><span class="p">)</span> <span class="n">attention_scores_dropout</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_dropout_layer</span><span class="p">(</span> <span class="n">attention_scores</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="n">training</span> <span class="p">)</span> <span class="n">attention_output</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span> <span class="bp">self</span><span class="o">.</span><span class="n">_combine_equation</span><span class="p">,</span> <span class="n">attention_scores_dropout</span><span class="p">,</span> <span class="n">value</span> <span class="p">)</span> <span class="k">return</span> <span class="n">attention_output</span><span class="p">,</span> <span class="n">attention_scores</span> </code></pre></div> <hr /> <h2 id="implement-the-mlp">Implement the MLP</h2> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">mlp</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">hidden_units</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="p">):</span> <span class="k">for</span> <span class="n">units</span> <span class="ow">in</span> <span class="n">hidden_units</span><span class="p">:</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">gelu</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_rate</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> <span class="c1"># Build the diagonal attention mask</span> <span class="n">diag_attn_mask</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">tf</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">NUM_PATCHES</span><span class="p">)</span> <span class="n">diag_attn_mask</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">cast</span><span class="p">([</span><span class="n">diag_attn_mask</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">int8</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="build-the-vit">Build the ViT</h2> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">create_vit_classifier</span><span class="p">(</span><span class="n">vanilla</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">INPUT_SHAPE</span><span class="p">)</span> <span class="c1"># Augment data.</span> <span class="n">augmented</span> <span class="o">=</span> <span class="n">data_augmentation</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="c1"># Create patches.</span> <span class="p">(</span><span class="n">tokens</span><span class="p">,</span> <span class="n">_</span><span class="p">)</span> <span class="o">=</span> <span class="n">ShiftedPatchTokenization</span><span class="p">(</span><span class="n">vanilla</span><span class="o">=</span><span class="n">vanilla</span><span class="p">)(</span><span class="n">augmented</span><span class="p">)</span> <span class="c1"># Encode patches.</span> <span class="n">encoded_patches</span> <span class="o">=</span> <span class="n">PatchEncoder</span><span class="p">()(</span><span class="n">tokens</span><span class="p">)</span> <span class="c1"># Create multiple layers of the Transformer block.</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">TRANSFORMER_LAYERS</span><span class="p">):</span> <span class="c1"># Layer normalization 1.</span> <span class="n">x1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span><span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">)(</span><span class="n">encoded_patches</span><span class="p">)</span> <span class="c1"># Create a multi-head attention layer.</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">vanilla</span><span class="p">:</span> <span class="n">attention_output</span> <span class="o">=</span> <span class="n">MultiHeadAttentionLSA</span><span class="p">(</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">NUM_HEADS</span><span class="p">,</span> <span class="n">key_dim</span><span class="o">=</span><span class="n">PROJECTION_DIM</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.1</span> <span class="p">)(</span><span class="n">x1</span><span class="p">,</span> <span class="n">x1</span><span class="p">,</span> <span class="n">attention_mask</span><span class="o">=</span><span class="n">diag_attn_mask</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">attention_output</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MultiHeadAttention</span><span class="p">(</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">NUM_HEADS</span><span class="p">,</span> <span class="n">key_dim</span><span class="o">=</span><span class="n">PROJECTION_DIM</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.1</span> <span class="p">)(</span><span class="n">x1</span><span class="p">,</span> <span class="n">x1</span><span class="p">)</span> <span class="c1"># Skip connection 1.</span> <span class="n">x2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Add</span><span class="p">()([</span><span class="n">attention_output</span><span class="p">,</span> <span class="n">encoded_patches</span><span class="p">])</span> <span class="c1"># Layer normalization 2.</span> <span class="n">x3</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span><span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">)(</span><span class="n">x2</span><span class="p">)</span> <span class="c1"># MLP.</span> <span class="n">x3</span> <span class="o">=</span> <span class="n">mlp</span><span class="p">(</span><span class="n">x3</span><span class="p">,</span> <span class="n">hidden_units</span><span class="o">=</span><span class="n">TRANSFORMER_UNITS</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">)</span> <span class="c1"># Skip connection 2.</span> <span class="n">encoded_patches</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Add</span><span class="p">()([</span><span class="n">x3</span><span class="p">,</span> <span class="n">x2</span><span class="p">])</span> <span class="c1"># Create a [batch_size, projection_dim] tensor.</span> <span class="n">representation</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span><span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">)(</span><span class="n">encoded_patches</span><span class="p">)</span> <span class="n">representation</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Flatten</span><span class="p">()(</span><span class="n">representation</span><span class="p">)</span> <span class="n">representation</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)(</span><span class="n">representation</span><span class="p">)</span> <span class="c1"># Add MLP.</span> <span class="n">features</span> <span class="o">=</span> <span class="n">mlp</span><span class="p">(</span><span class="n">representation</span><span class="p">,</span> <span class="n">hidden_units</span><span class="o">=</span><span class="n">MLP_HEAD_UNITS</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span> <span class="c1"># Classify outputs.</span> <span class="n">logits</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">NUM_CLASSES</span><span class="p">)(</span><span class="n">features</span><span class="p">)</span> <span class="c1"># Create the Keras model.</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">logits</span><span class="p">)</span> <span class="k">return</span> <span class="n">model</span> </code></pre></div> <hr /> <h2 id="compile-train-and-evaluate-the-mode">Compile, train, and evaluate the mode</h2> <div class="codehilite"><pre><span></span><code><span class="c1"># Some code is taken from:</span> <span class="c1"># https://www.kaggle.com/ashusma/training-rfcx-tensorflow-tpu-effnet-b2.</span> <span class="k">class</span> <span class="nc">WarmUpCosine</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">schedules</span><span class="o">.</span><span class="n">LearningRateSchedule</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">learning_rate_base</span><span class="p">,</span> <span class="n">total_steps</span><span class="p">,</span> <span class="n">warmup_learning_rate</span><span class="p">,</span> <span class="n">warmup_steps</span> <span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">learning_rate_base</span> <span class="o">=</span> <span class="n">learning_rate_base</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_steps</span> <span class="o">=</span> <span class="n">total_steps</span> <span class="bp">self</span><span class="o">.</span><span class="n">warmup_learning_rate</span> <span class="o">=</span> <span class="n">warmup_learning_rate</span> <span class="bp">self</span><span class="o">.</span><span class="n">warmup_steps</span> <span class="o">=</span> <span class="n">warmup_steps</span> <span class="bp">self</span><span class="o">.</span><span class="n">pi</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">pi</span><span class="p">)</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">step</span><span class="p">):</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">total_steps</span> <span class="o"><</span> <span class="bp">self</span><span class="o">.</span><span class="n">warmup_steps</span><span class="p">:</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"Total_steps must be larger or equal to warmup_steps."</span><span class="p">)</span> <span class="n">cos_annealed_lr</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span> <span class="bp">self</span><span class="o">.</span><span class="n">pi</span> <span class="o">*</span> <span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">warmup_steps</span><span class="p">)</span> <span class="o">/</span> <span class="nb">float</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">total_steps</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">warmup_steps</span><span class="p">)</span> <span class="p">)</span> <span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">0.5</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">learning_rate_base</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">+</span> <span class="n">cos_annealed_lr</span><span class="p">)</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">warmup_steps</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">learning_rate_base</span> <span class="o"><</span> <span class="bp">self</span><span class="o">.</span><span class="n">warmup_learning_rate</span><span class="p">:</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span> <span class="s2">"Learning_rate_base must be larger or equal to "</span> <span class="s2">"warmup_learning_rate."</span> <span class="p">)</span> <span class="n">slope</span> <span class="o">=</span> <span class="p">(</span> <span class="bp">self</span><span class="o">.</span><span class="n">learning_rate_base</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">warmup_learning_rate</span> <span class="p">)</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">warmup_steps</span> <span class="n">warmup_rate</span> <span class="o">=</span> <span class="n">slope</span> <span class="o">*</span> <span class="n">tf</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">step</span><span class="p">,</span> <span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">warmup_learning_rate</span> <span class="n">learning_rate</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">where</span><span class="p">(</span> <span class="n">step</span> <span class="o"><</span> <span class="bp">self</span><span class="o">.</span><span class="n">warmup_steps</span><span class="p">,</span> <span class="n">warmup_rate</span><span class="p">,</span> <span class="n">learning_rate</span> <span class="p">)</span> <span class="k">return</span> <span class="n">tf</span><span class="o">.</span><span class="n">where</span><span class="p">(</span> <span class="n">step</span> <span class="o">></span> <span class="bp">self</span><span class="o">.</span><span class="n">total_steps</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="n">learning_rate</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"learning_rate"</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">run_experiment</span><span class="p">(</span><span class="n">model</span><span class="p">):</span> <span class="n">total_steps</span> <span class="o">=</span> <span class="nb">int</span><span class="p">((</span><span class="nb">len</span><span class="p">(</span><span class="n">x_train</span><span class="p">)</span> <span class="o">/</span> <span class="n">BATCH_SIZE</span><span class="p">)</span> <span class="o">*</span> <span class="n">EPOCHS</span><span class="p">)</span> <span class="n">warmup_epoch_percentage</span> <span class="o">=</span> <span class="mf">0.10</span> <span class="n">warmup_steps</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">total_steps</span> <span class="o">*</span> <span class="n">warmup_epoch_percentage</span><span class="p">)</span> <span class="n">scheduled_lrs</span> <span class="o">=</span> <span class="n">WarmUpCosine</span><span class="p">(</span> <span class="n">learning_rate_base</span><span class="o">=</span><span class="n">LEARNING_RATE</span><span class="p">,</span> <span class="n">total_steps</span><span class="o">=</span><span class="n">total_steps</span><span class="p">,</span> <span class="n">warmup_learning_rate</span><span class="o">=</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">warmup_steps</span><span class="o">=</span><span class="n">warmup_steps</span><span class="p">,</span> <span class="p">)</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">tfa</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">AdamW</span><span class="p">(</span> <span class="n">learning_rate</span><span class="o">=</span><span class="n">LEARNING_RATE</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">WEIGHT_DECAY</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">SparseCategoricalAccuracy</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"accuracy"</span><span class="p">),</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">SparseTopKCategoricalAccuracy</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"top-5-accuracy"</span><span class="p">),</span> <span class="p">],</span> <span class="p">)</span> <span class="n">history</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">x</span><span class="o">=</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="n">y_train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">EPOCHS</span><span class="p">,</span> <span class="n">validation_split</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="p">)</span> <span class="n">_</span><span class="p">,</span> <span class="n">accuracy</span><span class="p">,</span> <span class="n">top_5_accuracy</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">BATCH_SIZE</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Test accuracy: </span><span class="si">{</span><span class="nb">round</span><span class="p">(</span><span class="n">accuracy</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="mi">100</span><span class="p">,</span><span class="w"> </span><span class="mi">2</span><span class="p">)</span><span class="si">}</span><span class="s2">%"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Test top 5 accuracy: </span><span class="si">{</span><span class="nb">round</span><span class="p">(</span><span class="n">top_5_accuracy</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="mi">100</span><span class="p">,</span><span class="w"> </span><span class="mi">2</span><span class="p">)</span><span class="si">}</span><span class="s2">%"</span><span class="p">)</span> <span class="k">return</span> <span class="n">history</span> <span class="c1"># Run experiments with the vanilla ViT</span> <span class="n">vit</span> <span class="o">=</span> <span class="n">create_vit_classifier</span><span class="p">(</span><span class="n">vanilla</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">history</span> <span class="o">=</span> <span class="n">run_experiment</span><span class="p">(</span><span class="n">vit</span><span class="p">)</span> <span class="c1"># Run experiments with the Shifted Patch Tokenization and</span> <span class="c1"># Locality Self Attention modified ViT</span> <span class="n">vit_sl</span> <span class="o">=</span> <span class="n">create_vit_classifier</span><span class="p">(</span><span class="n">vanilla</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="n">history</span> <span class="o">=</span> <span class="n">run_experiment</span><span class="p">(</span><span class="n">vit_sl</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/50 176/176 [==============================] - 22s 83ms/step - loss: 4.4912 - accuracy: 0.0427 - top-5-accuracy: 0.1549 - val_loss: 3.9409 - val_accuracy: 0.1030 - val_top-5-accuracy: 0.3036 Epoch 2/50 176/176 [==============================] - 14s 77ms/step - loss: 3.9749 - accuracy: 0.0897 - top-5-accuracy: 0.2802 - val_loss: 3.5721 - val_accuracy: 0.1550 - val_top-5-accuracy: 0.4058 Epoch 3/50 176/176 [==============================] - 14s 77ms/step - loss: 3.7129 - accuracy: 0.1282 - top-5-accuracy: 0.3601 - val_loss: 3.3235 - val_accuracy: 0.2022 - val_top-5-accuracy: 0.4788 Epoch 4/50 176/176 [==============================] - 14s 77ms/step - loss: 3.5518 - accuracy: 0.1544 - top-5-accuracy: 0.4078 - val_loss: 3.2432 - val_accuracy: 0.2132 - val_top-5-accuracy: 0.5056 Epoch 5/50 176/176 [==============================] - 14s 77ms/step - loss: 3.4098 - accuracy: 0.1828 - top-5-accuracy: 0.4471 - val_loss: 3.0910 - val_accuracy: 0.2462 - val_top-5-accuracy: 0.5376 Epoch 6/50 176/176 [==============================] - 14s 77ms/step - loss: 3.2835 - accuracy: 0.2037 - top-5-accuracy: 0.4838 - val_loss: 2.9803 - val_accuracy: 0.2704 - val_top-5-accuracy: 0.5606 Epoch 7/50 176/176 [==============================] - 14s 77ms/step - loss: 3.1756 - accuracy: 0.2205 - top-5-accuracy: 0.5113 - val_loss: 2.8608 - val_accuracy: 0.2802 - val_top-5-accuracy: 0.5908 Epoch 8/50 176/176 [==============================] - 14s 77ms/step - loss: 3.0585 - accuracy: 0.2439 - top-5-accuracy: 0.5432 - val_loss: 2.8055 - val_accuracy: 0.2960 - val_top-5-accuracy: 0.6144 Epoch 9/50 176/176 [==============================] - 14s 77ms/step - loss: 2.9457 - accuracy: 0.2654 - top-5-accuracy: 0.5697 - val_loss: 2.7034 - val_accuracy: 0.3210 - val_top-5-accuracy: 0.6242 Epoch 10/50 176/176 [==============================] - 14s 77ms/step - loss: 2.8458 - accuracy: 0.2863 - top-5-accuracy: 0.5918 - val_loss: 2.5899 - val_accuracy: 0.3416 - val_top-5-accuracy: 0.6500 Epoch 11/50 176/176 [==============================] - 14s 77ms/step - loss: 2.7530 - accuracy: 0.3052 - top-5-accuracy: 0.6191 - val_loss: 2.5275 - val_accuracy: 0.3526 - val_top-5-accuracy: 0.6660 Epoch 12/50 176/176 [==============================] - 14s 77ms/step - loss: 2.6561 - accuracy: 0.3250 - top-5-accuracy: 0.6355 - val_loss: 2.5111 - val_accuracy: 0.3544 - val_top-5-accuracy: 0.6554 Epoch 13/50 176/176 [==============================] - 14s 77ms/step - loss: 2.5833 - accuracy: 0.3398 - top-5-accuracy: 0.6538 - val_loss: 2.3931 - val_accuracy: 0.3792 - val_top-5-accuracy: 0.6888 Epoch 14/50 176/176 [==============================] - 14s 77ms/step - loss: 2.4988 - accuracy: 0.3594 - top-5-accuracy: 0.6724 - val_loss: 2.3695 - val_accuracy: 0.3868 - val_top-5-accuracy: 0.6958 Epoch 15/50 176/176 [==============================] - 14s 77ms/step - loss: 2.4342 - accuracy: 0.3706 - top-5-accuracy: 0.6877 - val_loss: 2.3076 - val_accuracy: 0.4072 - val_top-5-accuracy: 0.7074 Epoch 16/50 176/176 [==============================] - 14s 77ms/step - loss: 2.3654 - accuracy: 0.3841 - top-5-accuracy: 0.7024 - val_loss: 2.2346 - val_accuracy: 0.4202 - val_top-5-accuracy: 0.7174 Epoch 17/50 176/176 [==============================] - 14s 77ms/step - loss: 2.3062 - accuracy: 0.3967 - top-5-accuracy: 0.7130 - val_loss: 2.2277 - val_accuracy: 0.4206 - val_top-5-accuracy: 0.7190 Epoch 18/50 176/176 [==============================] - 14s 77ms/step - loss: 2.2415 - accuracy: 0.4100 - top-5-accuracy: 0.7271 - val_loss: 2.1605 - val_accuracy: 0.4398 - val_top-5-accuracy: 0.7366 Epoch 19/50 176/176 [==============================] - 14s 77ms/step - loss: 2.1802 - accuracy: 0.4240 - top-5-accuracy: 0.7386 - val_loss: 2.1533 - val_accuracy: 0.4428 - val_top-5-accuracy: 0.7382 Epoch 20/50 176/176 [==============================] - 14s 77ms/step - loss: 2.1264 - accuracy: 0.4357 - top-5-accuracy: 0.7486 - val_loss: 2.1395 - val_accuracy: 0.4428 - val_top-5-accuracy: 0.7404 Epoch 21/50 176/176 [==============================] - 14s 77ms/step - loss: 2.0856 - accuracy: 0.4442 - top-5-accuracy: 0.7564 - val_loss: 2.1025 - val_accuracy: 0.4512 - val_top-5-accuracy: 0.7448 Epoch 22/50 176/176 [==============================] - 14s 77ms/step - loss: 2.0320 - accuracy: 0.4566 - top-5-accuracy: 0.7668 - val_loss: 2.0677 - val_accuracy: 0.4600 - val_top-5-accuracy: 0.7534 Epoch 23/50 176/176 [==============================] - 14s 77ms/step - loss: 1.9903 - accuracy: 0.4666 - top-5-accuracy: 0.7761 - val_loss: 2.0273 - val_accuracy: 0.4650 - val_top-5-accuracy: 0.7610 Epoch 24/50 176/176 [==============================] - 14s 77ms/step - loss: 1.9398 - accuracy: 0.4772 - top-5-accuracy: 0.7877 - val_loss: 2.0253 - val_accuracy: 0.4694 - val_top-5-accuracy: 0.7636 Epoch 25/50 176/176 [==============================] - 14s 78ms/step - loss: 1.9027 - accuracy: 0.4865 - top-5-accuracy: 0.7933 - val_loss: 2.0584 - val_accuracy: 0.4606 - val_top-5-accuracy: 0.7520 Epoch 26/50 176/176 [==============================] - 14s 77ms/step - loss: 1.8529 - accuracy: 0.4964 - top-5-accuracy: 0.8010 - val_loss: 2.0128 - val_accuracy: 0.4752 - val_top-5-accuracy: 0.7654 Epoch 27/50 176/176 [==============================] - 14s 77ms/step - loss: 1.8161 - accuracy: 0.5047 - top-5-accuracy: 0.8111 - val_loss: 1.9630 - val_accuracy: 0.4898 - val_top-5-accuracy: 0.7746 Epoch 28/50 176/176 [==============================] - 13s 77ms/step - loss: 1.7792 - accuracy: 0.5136 - top-5-accuracy: 0.8140 - val_loss: 1.9931 - val_accuracy: 0.4780 - val_top-5-accuracy: 0.7640 Epoch 29/50 176/176 [==============================] - 14s 77ms/step - loss: 1.7268 - accuracy: 0.5211 - top-5-accuracy: 0.8250 - val_loss: 1.9748 - val_accuracy: 0.4854 - val_top-5-accuracy: 0.7708 Epoch 30/50 176/176 [==============================] - 14s 77ms/step - loss: 1.7115 - accuracy: 0.5298 - top-5-accuracy: 0.8265 - val_loss: 1.9669 - val_accuracy: 0.4884 - val_top-5-accuracy: 0.7796 Epoch 31/50 176/176 [==============================] - 14s 77ms/step - loss: 1.6795 - accuracy: 0.5361 - top-5-accuracy: 0.8329 - val_loss: 1.9428 - val_accuracy: 0.4972 - val_top-5-accuracy: 0.7852 Epoch 32/50 176/176 [==============================] - 14s 77ms/step - loss: 1.6411 - accuracy: 0.5448 - top-5-accuracy: 0.8412 - val_loss: 1.9318 - val_accuracy: 0.4952 - val_top-5-accuracy: 0.7864 Epoch 33/50 176/176 [==============================] - 14s 77ms/step - loss: 1.6015 - accuracy: 0.5547 - top-5-accuracy: 0.8466 - val_loss: 1.9233 - val_accuracy: 0.4996 - val_top-5-accuracy: 0.7882 Epoch 34/50 176/176 [==============================] - 14s 77ms/step - loss: 1.5651 - accuracy: 0.5655 - top-5-accuracy: 0.8525 - val_loss: 1.9285 - val_accuracy: 0.5082 - val_top-5-accuracy: 0.7888 Epoch 35/50 176/176 [==============================] - 14s 77ms/step - loss: 1.5437 - accuracy: 0.5672 - top-5-accuracy: 0.8570 - val_loss: 1.9268 - val_accuracy: 0.5028 - val_top-5-accuracy: 0.7842 Epoch 36/50 176/176 [==============================] - 14s 77ms/step - loss: 1.5103 - accuracy: 0.5748 - top-5-accuracy: 0.8620 - val_loss: 1.9262 - val_accuracy: 0.5014 - val_top-5-accuracy: 0.7890 Epoch 37/50 176/176 [==============================] - 14s 77ms/step - loss: 1.4784 - accuracy: 0.5822 - top-5-accuracy: 0.8690 - val_loss: 1.8698 - val_accuracy: 0.5130 - val_top-5-accuracy: 0.7948 Epoch 38/50 176/176 [==============================] - 14s 77ms/step - loss: 1.4449 - accuracy: 0.5922 - top-5-accuracy: 0.8728 - val_loss: 1.8734 - val_accuracy: 0.5136 - val_top-5-accuracy: 0.7980 Epoch 39/50 176/176 [==============================] - 14s 77ms/step - loss: 1.4312 - accuracy: 0.5928 - top-5-accuracy: 0.8755 - val_loss: 1.8736 - val_accuracy: 0.5150 - val_top-5-accuracy: 0.7956 Epoch 40/50 176/176 [==============================] - 14s 77ms/step - loss: 1.3996 - accuracy: 0.5999 - top-5-accuracy: 0.8808 - val_loss: 1.8718 - val_accuracy: 0.5178 - val_top-5-accuracy: 0.7970 Epoch 41/50 176/176 [==============================] - 14s 77ms/step - loss: 1.3859 - accuracy: 0.6075 - top-5-accuracy: 0.8817 - val_loss: 1.9097 - val_accuracy: 0.5084 - val_top-5-accuracy: 0.7884 Epoch 42/50 176/176 [==============================] - 14s 77ms/step - loss: 1.3586 - accuracy: 0.6119 - top-5-accuracy: 0.8860 - val_loss: 1.8620 - val_accuracy: 0.5148 - val_top-5-accuracy: 0.8010 Epoch 43/50 176/176 [==============================] - 14s 77ms/step - loss: 1.3384 - accuracy: 0.6154 - top-5-accuracy: 0.8911 - val_loss: 1.8509 - val_accuracy: 0.5202 - val_top-5-accuracy: 0.8014 Epoch 44/50 176/176 [==============================] - 14s 78ms/step - loss: 1.3090 - accuracy: 0.6236 - top-5-accuracy: 0.8954 - val_loss: 1.8607 - val_accuracy: 0.5242 - val_top-5-accuracy: 0.8020 Epoch 45/50 176/176 [==============================] - 14s 78ms/step - loss: 1.2873 - accuracy: 0.6292 - top-5-accuracy: 0.8964 - val_loss: 1.8729 - val_accuracy: 0.5208 - val_top-5-accuracy: 0.8056 Epoch 46/50 176/176 [==============================] - 14s 77ms/step - loss: 1.2658 - accuracy: 0.6367 - top-5-accuracy: 0.9007 - val_loss: 1.8573 - val_accuracy: 0.5278 - val_top-5-accuracy: 0.8066 Epoch 47/50 176/176 [==============================] - 14s 77ms/step - loss: 1.2628 - accuracy: 0.6346 - top-5-accuracy: 0.9023 - val_loss: 1.8240 - val_accuracy: 0.5292 - val_top-5-accuracy: 0.8112 Epoch 48/50 176/176 [==============================] - 14s 78ms/step - loss: 1.2396 - accuracy: 0.6431 - top-5-accuracy: 0.9057 - val_loss: 1.8342 - val_accuracy: 0.5362 - val_top-5-accuracy: 0.8096 Epoch 49/50 176/176 [==============================] - 14s 77ms/step - loss: 1.2163 - accuracy: 0.6464 - top-5-accuracy: 0.9081 - val_loss: 1.8836 - val_accuracy: 0.5246 - val_top-5-accuracy: 0.8044 Epoch 50/50 176/176 [==============================] - 14s 77ms/step - loss: 1.1919 - accuracy: 0.6541 - top-5-accuracy: 0.9122 - val_loss: 1.8513 - val_accuracy: 0.5336 - val_top-5-accuracy: 0.8048 40/40 [==============================] - 1s 26ms/step - loss: 1.8172 - accuracy: 0.5310 - top-5-accuracy: 0.8053 Test accuracy: 53.1% Test top 5 accuracy: 80.53% Epoch 1/50 176/176 [==============================] - 23s 90ms/step - loss: 4.4889 - accuracy: 0.0450 - top-5-accuracy: 0.1559 - val_loss: 3.9364 - val_accuracy: 0.1128 - val_top-5-accuracy: 0.3184 Epoch 2/50 176/176 [==============================] - 15s 85ms/step - loss: 3.9806 - accuracy: 0.0924 - top-5-accuracy: 0.2798 - val_loss: 3.6392 - val_accuracy: 0.1576 - val_top-5-accuracy: 0.4034 Epoch 3/50 176/176 [==============================] - 15s 84ms/step - loss: 3.7713 - accuracy: 0.1253 - top-5-accuracy: 0.3448 - val_loss: 3.3892 - val_accuracy: 0.1918 - val_top-5-accuracy: 0.4622 Epoch 4/50 176/176 [==============================] - 15s 85ms/step - loss: 3.6297 - accuracy: 0.1460 - top-5-accuracy: 0.3859 - val_loss: 3.2856 - val_accuracy: 0.2194 - val_top-5-accuracy: 0.4970 Epoch 5/50 176/176 [==============================] - 15s 85ms/step - loss: 3.4955 - accuracy: 0.1706 - top-5-accuracy: 0.4239 - val_loss: 3.1359 - val_accuracy: 0.2412 - val_top-5-accuracy: 0.5308 Epoch 6/50 176/176 [==============================] - 15s 85ms/step - loss: 3.3781 - accuracy: 0.1908 - top-5-accuracy: 0.4565 - val_loss: 3.0535 - val_accuracy: 0.2620 - val_top-5-accuracy: 0.5652 Epoch 7/50 176/176 [==============================] - 15s 85ms/step - loss: 3.2540 - accuracy: 0.2123 - top-5-accuracy: 0.4895 - val_loss: 2.9165 - val_accuracy: 0.2782 - val_top-5-accuracy: 0.5800 Epoch 8/50 176/176 [==============================] - 15s 85ms/step - loss: 3.1442 - accuracy: 0.2318 - top-5-accuracy: 0.5197 - val_loss: 2.8592 - val_accuracy: 0.2984 - val_top-5-accuracy: 0.6090 Epoch 9/50 176/176 [==============================] - 15s 85ms/step - loss: 3.0348 - accuracy: 0.2504 - top-5-accuracy: 0.5440 - val_loss: 2.7378 - val_accuracy: 0.3146 - val_top-5-accuracy: 0.6294 Epoch 10/50 176/176 [==============================] - 15s 84ms/step - loss: 2.9311 - accuracy: 0.2681 - top-5-accuracy: 0.5704 - val_loss: 2.6274 - val_accuracy: 0.3362 - val_top-5-accuracy: 0.6446 Epoch 11/50 176/176 [==============================] - 15s 85ms/step - loss: 2.8214 - accuracy: 0.2925 - top-5-accuracy: 0.5986 - val_loss: 2.5557 - val_accuracy: 0.3458 - val_top-5-accuracy: 0.6616 Epoch 12/50 176/176 [==============================] - 15s 85ms/step - loss: 2.7244 - accuracy: 0.3100 - top-5-accuracy: 0.6168 - val_loss: 2.4763 - val_accuracy: 0.3564 - val_top-5-accuracy: 0.6804 Epoch 13/50 176/176 [==============================] - 15s 85ms/step - loss: 2.6476 - accuracy: 0.3255 - top-5-accuracy: 0.6358 - val_loss: 2.3946 - val_accuracy: 0.3678 - val_top-5-accuracy: 0.6940 Epoch 14/50 176/176 [==============================] - 15s 85ms/step - loss: 2.5518 - accuracy: 0.3436 - top-5-accuracy: 0.6584 - val_loss: 2.3362 - val_accuracy: 0.3856 - val_top-5-accuracy: 0.7038 Epoch 15/50 176/176 [==============================] - 15s 85ms/step - loss: 2.4620 - accuracy: 0.3632 - top-5-accuracy: 0.6776 - val_loss: 2.2690 - val_accuracy: 0.4006 - val_top-5-accuracy: 0.7222 Epoch 16/50 176/176 [==============================] - 15s 85ms/step - loss: 2.4010 - accuracy: 0.3749 - top-5-accuracy: 0.6908 - val_loss: 2.1937 - val_accuracy: 0.4216 - val_top-5-accuracy: 0.7338 Epoch 17/50 176/176 [==============================] - 15s 85ms/step - loss: 2.3330 - accuracy: 0.3911 - top-5-accuracy: 0.7041 - val_loss: 2.1519 - val_accuracy: 0.4286 - val_top-5-accuracy: 0.7370 Epoch 18/50 176/176 [==============================] - 15s 85ms/step - loss: 2.2600 - accuracy: 0.4069 - top-5-accuracy: 0.7171 - val_loss: 2.1212 - val_accuracy: 0.4356 - val_top-5-accuracy: 0.7460 Epoch 19/50 176/176 [==============================] - 15s 85ms/step - loss: 2.1967 - accuracy: 0.4169 - top-5-accuracy: 0.7320 - val_loss: 2.0748 - val_accuracy: 0.4470 - val_top-5-accuracy: 0.7580 Epoch 20/50 176/176 [==============================] - 15s 85ms/step - loss: 2.1397 - accuracy: 0.4302 - top-5-accuracy: 0.7450 - val_loss: 2.1152 - val_accuracy: 0.4362 - val_top-5-accuracy: 0.7416 Epoch 21/50 176/176 [==============================] - 15s 85ms/step - loss: 2.0929 - accuracy: 0.4396 - top-5-accuracy: 0.7524 - val_loss: 2.0044 - val_accuracy: 0.4652 - val_top-5-accuracy: 0.7680 Epoch 22/50 176/176 [==============================] - 15s 85ms/step - loss: 2.0423 - accuracy: 0.4521 - top-5-accuracy: 0.7639 - val_loss: 2.0628 - val_accuracy: 0.4488 - val_top-5-accuracy: 0.7544 Epoch 23/50 176/176 [==============================] - 15s 85ms/step - loss: 1.9771 - accuracy: 0.4661 - top-5-accuracy: 0.7750 - val_loss: 1.9380 - val_accuracy: 0.4740 - val_top-5-accuracy: 0.7836 Epoch 24/50 176/176 [==============================] - 15s 84ms/step - loss: 1.9323 - accuracy: 0.4752 - top-5-accuracy: 0.7848 - val_loss: 1.9461 - val_accuracy: 0.4732 - val_top-5-accuracy: 0.7768 Epoch 25/50 176/176 [==============================] - 15s 85ms/step - loss: 1.8913 - accuracy: 0.4844 - top-5-accuracy: 0.7914 - val_loss: 1.9230 - val_accuracy: 0.4768 - val_top-5-accuracy: 0.7886 Epoch 26/50 176/176 [==============================] - 15s 84ms/step - loss: 1.8520 - accuracy: 0.4950 - top-5-accuracy: 0.7999 - val_loss: 1.9159 - val_accuracy: 0.4808 - val_top-5-accuracy: 0.7900 Epoch 27/50 176/176 [==============================] - 15s 85ms/step - loss: 1.8175 - accuracy: 0.5046 - top-5-accuracy: 0.8076 - val_loss: 1.8977 - val_accuracy: 0.4896 - val_top-5-accuracy: 0.7876 Epoch 28/50 176/176 [==============================] - 15s 85ms/step - loss: 1.7692 - accuracy: 0.5133 - top-5-accuracy: 0.8146 - val_loss: 1.8632 - val_accuracy: 0.4940 - val_top-5-accuracy: 0.7920 Epoch 29/50 176/176 [==============================] - 15s 85ms/step - loss: 1.7375 - accuracy: 0.5193 - top-5-accuracy: 0.8206 - val_loss: 1.8686 - val_accuracy: 0.4926 - val_top-5-accuracy: 0.7952 Epoch 30/50 176/176 [==============================] - 15s 85ms/step - loss: 1.6952 - accuracy: 0.5308 - top-5-accuracy: 0.8280 - val_loss: 1.8265 - val_accuracy: 0.5024 - val_top-5-accuracy: 0.7996 Epoch 31/50 176/176 [==============================] - 15s 85ms/step - loss: 1.6631 - accuracy: 0.5379 - top-5-accuracy: 0.8348 - val_loss: 1.8665 - val_accuracy: 0.4942 - val_top-5-accuracy: 0.7854 Epoch 32/50 176/176 [==============================] - 15s 85ms/step - loss: 1.6329 - accuracy: 0.5466 - top-5-accuracy: 0.8401 - val_loss: 1.8364 - val_accuracy: 0.5090 - val_top-5-accuracy: 0.7996 Epoch 33/50 176/176 [==============================] - 15s 85ms/step - loss: 1.5960 - accuracy: 0.5537 - top-5-accuracy: 0.8465 - val_loss: 1.8171 - val_accuracy: 0.5136 - val_top-5-accuracy: 0.8034 Epoch 34/50 176/176 [==============================] - 15s 85ms/step - loss: 1.5815 - accuracy: 0.5578 - top-5-accuracy: 0.8476 - val_loss: 1.8020 - val_accuracy: 0.5128 - val_top-5-accuracy: 0.8042 Epoch 35/50 176/176 [==============================] - 15s 85ms/step - loss: 1.5432 - accuracy: 0.5667 - top-5-accuracy: 0.8566 - val_loss: 1.8173 - val_accuracy: 0.5142 - val_top-5-accuracy: 0.8080 Epoch 36/50 176/176 [==============================] - 15s 85ms/step - loss: 1.5110 - accuracy: 0.5768 - top-5-accuracy: 0.8594 - val_loss: 1.8168 - val_accuracy: 0.5124 - val_top-5-accuracy: 0.8066 Epoch 37/50 176/176 [==============================] - 15s 85ms/step - loss: 1.4890 - accuracy: 0.5816 - top-5-accuracy: 0.8641 - val_loss: 1.7861 - val_accuracy: 0.5274 - val_top-5-accuracy: 0.8120 Epoch 38/50 176/176 [==============================] - 15s 85ms/step - loss: 1.4672 - accuracy: 0.5849 - top-5-accuracy: 0.8660 - val_loss: 1.7695 - val_accuracy: 0.5222 - val_top-5-accuracy: 0.8106 Epoch 39/50 176/176 [==============================] - 15s 85ms/step - loss: 1.4323 - accuracy: 0.5939 - top-5-accuracy: 0.8721 - val_loss: 1.7653 - val_accuracy: 0.5250 - val_top-5-accuracy: 0.8164 Epoch 40/50 176/176 [==============================] - 15s 85ms/step - loss: 1.4192 - accuracy: 0.5975 - top-5-accuracy: 0.8754 - val_loss: 1.7727 - val_accuracy: 0.5298 - val_top-5-accuracy: 0.8154 Epoch 41/50 176/176 [==============================] - 15s 85ms/step - loss: 1.3897 - accuracy: 0.6055 - top-5-accuracy: 0.8805 - val_loss: 1.7535 - val_accuracy: 0.5328 - val_top-5-accuracy: 0.8122 Epoch 42/50 176/176 [==============================] - 15s 85ms/step - loss: 1.3702 - accuracy: 0.6087 - top-5-accuracy: 0.8828 - val_loss: 1.7746 - val_accuracy: 0.5316 - val_top-5-accuracy: 0.8116 Epoch 43/50 176/176 [==============================] - 15s 85ms/step - loss: 1.3338 - accuracy: 0.6185 - top-5-accuracy: 0.8894 - val_loss: 1.7606 - val_accuracy: 0.5342 - val_top-5-accuracy: 0.8176 Epoch 44/50 176/176 [==============================] - 15s 85ms/step - loss: 1.3171 - accuracy: 0.6200 - top-5-accuracy: 0.8920 - val_loss: 1.7490 - val_accuracy: 0.5364 - val_top-5-accuracy: 0.8164 Epoch 45/50 176/176 [==============================] - 15s 85ms/step - loss: 1.3056 - accuracy: 0.6276 - top-5-accuracy: 0.8932 - val_loss: 1.7535 - val_accuracy: 0.5388 - val_top-5-accuracy: 0.8156 Epoch 46/50 176/176 [==============================] - 15s 85ms/step - loss: 1.2876 - accuracy: 0.6289 - top-5-accuracy: 0.8952 - val_loss: 1.7546 - val_accuracy: 0.5320 - val_top-5-accuracy: 0.8154 Epoch 47/50 176/176 [==============================] - 15s 85ms/step - loss: 1.2764 - accuracy: 0.6350 - top-5-accuracy: 0.8970 - val_loss: 1.7177 - val_accuracy: 0.5382 - val_top-5-accuracy: 0.8200 Epoch 48/50 176/176 [==============================] - 15s 85ms/step - loss: 1.2543 - accuracy: 0.6407 - top-5-accuracy: 0.9001 - val_loss: 1.7330 - val_accuracy: 0.5438 - val_top-5-accuracy: 0.8198 Epoch 49/50 176/176 [==============================] - 15s 84ms/step - loss: 1.2191 - accuracy: 0.6470 - top-5-accuracy: 0.9042 - val_loss: 1.7316 - val_accuracy: 0.5436 - val_top-5-accuracy: 0.8196 Epoch 50/50 176/176 [==============================] - 15s 85ms/step - loss: 1.2186 - accuracy: 0.6457 - top-5-accuracy: 0.9066 - val_loss: 1.7201 - val_accuracy: 0.5486 - val_top-5-accuracy: 0.8218 40/40 [==============================] - 1s 30ms/step - loss: 1.6760 - accuracy: 0.5611 - top-5-accuracy: 0.8227 Test accuracy: 56.11% Test top 5 accuracy: 82.27% </code></pre></div> </div> <h1 id="final-notes">Final Notes</h1> <p>With the help of Shifted Patch Tokenization and Locality Self Attention, we were able to get ~<strong>3-4%</strong> top-1 accuracy gains on CIFAR100.</p> <p>The ideas on Shifted Patch Tokenization and Locality Self Attention are very intuitive and easy to implement. The authors also ablates of different shifting strategies for Shifted Patch Tokenization in the supplementary of the paper.</p> <p>I would like to thank <a href="https://jarvislabs.ai/">Jarvislabs.ai</a> for generously helping with GPU credits.</p> <p>You can use the trained model hosted on <a href="https://huggingface.co/keras-io/vit_small_ds_v2">Hugging Face Hub</a> and try the demo on <a href="https://huggingface.co/spaces/keras-io/vit-small-ds">Hugging Face Spaces</a>.</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#train-a-vision-transformer-on-small-datasets'>Train a Vision Transformer on small datasets</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#prepare-the-data'>Prepare the data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#configure-the-hyperparameters'>Configure the hyperparameters</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#use-data-augmentation'>Use data augmentation</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#implement-shifted-patch-tokenization'>Implement Shifted Patch Tokenization</a> </div> <div class='k-outline-depth-3'> <a href='#visualize-the-patches'>Visualize the patches</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#implement-the-patch-encoding-layer'>Implement the patch encoding layer</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#implement-locality-self-attention'>Implement Locality Self Attention</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#implement-the-mlp'>Implement the MLP</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#build-the-vit'>Build the ViT</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#compile-train-and-evaluate-the-mode'>Compile, train, and evaluate the mode</a> </div> <div class='k-outline-depth-1'> <a href='#final-notes'>Final Notes</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>