Image classification with EANet (External Attention Transformer)
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/eanet/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Image classification with EANet (External Attention Transformer)"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Image classification with EANet (External Attention Transformer)"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Image classification with EANet (External Attention Transformer)</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) 
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a 
class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a class="nav-sublink2" href="/examples/vision/mobilevit/">A mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2 active" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries 
segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using Composable Fully-Convolutional Networks</a> <a class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image 
classification</a> <a class="nav-sublink2" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved Robustness</a> <a class="nav-sublink2" href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video 
Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/bit/">Image Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face 
Transformers</a> <a class="nav-sublink2" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a class="nav-sublink2" href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { 
document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = 
document.getElementById('search-input').value; window.location.href = '/search.html?query=' + encodeURIComponent(query); return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-content'> <h1 id="image-classification-with-eanet-external-attention-transformer">Image classification with EANet (External Attention Transformer)</h1> <p><strong>Author:</strong> <a href="https://github.com/czy00000">ZhiYong Chang</a><br> <strong>Date created:</strong> 2021/10/19<br> <strong>Last modified:</strong> 2023/07/18<br> <strong>Description:</strong> Image classification with a Transformer that leverages external attention.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/eanet.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/eanet.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>This example implements the <a href="https://arxiv.org/abs/2105.02358">EANet</a> model for image classification and demonstrates it on the CIFAR-100 dataset. EANet introduces a novel attention mechanism named <strong><em>external attention</em></strong>, based on two external, small, learnable, and shared memories. It can be implemented with just two cascaded linear layers and two normalization layers, and it can serve as a drop-in replacement for the self-attention used in existing architectures. 
External attention has linear complexity in the number of input elements (here, image patches), and because its memories are shared across the whole dataset, it implicitly considers the correlations between all samples.</p> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">keras</span>
<span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span>
<span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">ops</span>

<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
</code></pre></div> <hr /> <h2 id="prepare-the-data">Prepare the data</h2> <div class="codehilite"><pre><span></span><code><span class="n">num_classes</span> <span class="o">=</span> <span class="mi">100</span>
<span class="n">input_shape</span> <span class="o">=</span> <span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span>

<span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">cifar100</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span>
<span class="n">y_train</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">to_categorical</span><span class="p">(</span><span class="n">y_train</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span>
<span class="n">y_test</span> <span class="o">=</span> <span 
class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">to_categorical</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"x_train shape: </span><span class="si">{</span><span class="n">x_train</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2"> - y_train shape: </span><span class="si">{</span><span class="n">y_train</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"x_test shape: </span><span class="si">{</span><span class="n">x_test</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2"> - y_test shape: </span><span class="si">{</span><span class="n">y_test</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Downloading data from https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
169001437/169001437 ━━━━━━━━━━━━━━━━━━━━ 3s 0us/step
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 100)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 100)
</code></pre></div> </div> <hr /> <h2 id="configure-the-hyperparameters">Configure the hyperparameters</h2> <div class="codehilite"><pre><span></span><code><span class="n">weight_decay</span> <span class="o">=</span> <span class="mf">0.0001</span>
<span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">0.001</span>
<span class="n">label_smoothing</span> <span class="o">=</span> <span 
class="mf">0.1</span>
<span class="n">validation_split</span> <span class="o">=</span> <span class="mf">0.2</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">128</span>
<span class="n">num_epochs</span> <span class="o">=</span> <span class="mi">50</span>
<span class="n">patch_size</span> <span class="o">=</span> <span class="mi">2</span>  <span class="c1"># Size of the patches to be extracted from the input images.</span>
<span class="n">num_patches</span> <span class="o">=</span> <span class="p">(</span><span class="n">input_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">//</span> <span class="n">patch_size</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span>  <span class="c1"># Number of patches per image.</span>
<span class="n">embedding_dim</span> <span class="o">=</span> <span class="mi">64</span>  <span class="c1"># Number of hidden units.</span>
<span class="n">mlp_dim</span> <span class="o">=</span> <span class="mi">64</span>
<span class="n">dim_coefficient</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">num_heads</span> <span class="o">=</span> <span class="mi">4</span>
<span class="n">attention_dropout</span> <span class="o">=</span> <span class="mf">0.2</span>
<span class="n">projection_dropout</span> <span class="o">=</span> <span class="mf">0.2</span>
<span class="n">num_transformer_blocks</span> <span class="o">=</span> <span class="mi">8</span>  <span class="c1"># Number of repetitions of the transformer layer.</span>

<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Patch size: </span><span class="si">{</span><span class="n">patch_size</span><span class="si">}</span><span class="s2"> X </span><span class="si">{</span><span class="n">patch_size</span><span class="si">}</span><span class="s2"> = </span><span class="si">{</span><span class="n">patch_size</span><span class="w"> </span><span class="o">**</span><span class="w"> </span><span class="mi">2</span><span class="si">}</span><span class="s2"> "</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Patches per image: </span><span class="si">{</span><span class="n">num_patches</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Patch size: 2 X 2 = 4 
Patches per image: 256
</code></pre></div> </div> <hr /> <h2 id="use-data-augmentation">Use data augmentation</h2> <div class="codehilite"><pre><span></span><code><span class="n">data_augmentation</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
    <span class="p">[</span>
        <span class="n">layers</span><span class="o">.</span><span class="n">Normalization</span><span class="p">(),</span>
        <span class="n">layers</span><span class="o">.</span><span class="n">RandomFlip</span><span class="p">(</span><span class="s2">"horizontal"</span><span class="p">),</span>
        <span class="n">layers</span><span class="o">.</span><span class="n">RandomRotation</span><span class="p">(</span><span class="n">factor</span><span class="o">=</span><span class="mf">0.1</span><span class="p">),</span>
        <span class="n">layers</span><span class="o">.</span><span class="n">RandomContrast</span><span class="p">(</span><span class="n">factor</span><span class="o">=</span><span class="mf">0.1</span><span class="p">),</span>
        <span class="n">layers</span><span class="o">.</span><span class="n">RandomZoom</span><span class="p">(</span><span class="n">height_factor</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">width_factor</span><span class="o">=</span><span class="mf">0.2</span><span class="p">),</span>
    <span class="p">],</span>
    <span 
class="n">name</span><span class="o">=</span><span class="s2">"data_augmentation"</span><span class="p">,</span>
<span class="p">)</span>
<span class="c1"># Compute the mean and the variance of the training data for normalization.</span>
<span class="n">data_augmentation</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">adapt</span><span class="p">(</span><span class="n">x_train</span><span class="p">)</span>
</code></pre></div> <hr /> <h2 id="implement-the-patch-extraction-and-encoding-layer">Implement the patch extraction and encoding layer</h2> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">PatchExtract</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">patch_size</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span> <span class="o">=</span> <span class="n">patch_size</span>

    <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">B</span><span class="p">,</span> <span class="n">C</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">)[</span><span 
class="mi">0</span><span class="p">],</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">extract_patches</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span><span class="p">)</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="p">(</span><span class="n">B</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span> <span class="o">*</span> <span class="n">C</span><span class="p">))</span>
        <span class="k">return</span> <span class="n">x</span>


<span class="k">class</span> <span class="nc">PatchEmbedding</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_patch</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span 
class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">num_patch</span> <span class="o">=</span> <span class="n">num_patch</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">proj</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">pos_embed</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span><span class="n">input_dim</span><span class="o">=</span><span class="n">num_patch</span><span class="p">,</span> <span class="n">output_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">)</span>

    <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">patch</span><span class="p">):</span>
        <span class="n">pos</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">start</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">stop</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_patch</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">proj</span><span class="p">(</span><span class="n">patch</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">pos_embed</span><span class="p">(</span><span class="n">pos</span><span 
class="p">)</span> </code></pre></div> <hr /> <h2 id="implement-the-external-attention-block">Implement the external attention block</h2> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">external_attention</span><span class="p">(</span> <span class="n">x</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">dim_coefficient</span><span class="o">=</span><span class="mi">4</span><span class="p">,</span> <span class="n">attention_dropout</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">projection_dropout</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="p">):</span> <span class="n">_</span><span class="p">,</span> <span class="n">num_patch</span><span class="p">,</span> <span class="n">channel</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span> <span class="k">assert</span> <span class="n">dim</span> <span class="o">%</span> <span class="n">num_heads</span> <span class="o">==</span> <span class="mi">0</span> <span class="n">num_heads</span> <span class="o">=</span> <span class="n">num_heads</span> <span class="o">*</span> <span class="n">dim_coefficient</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">dim</span> <span class="o">*</span> <span class="n">dim_coefficient</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># create tensor [batch_size, num_patches, num_heads, dim*dim_coefficient//num_heads]</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> 
<span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_patch</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">dim</span> <span class="o">*</span> <span class="n">dim_coefficient</span> <span class="o">//</span> <span class="n">num_heads</span><span class="p">))</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="c1"># a linear layer M_k</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">dim</span> <span class="o">//</span> <span class="n">dim_coefficient</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># normalize the attention map: softmax over the patch axis</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Softmax</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">2</span><span class="p">)(</span><span class="n">attn</span><span class="p">)</span> <span class="c1"># double normalization: l1-normalize over the memory axis</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Lambda</span><span class="p">(</span> <span class="k">lambda</span> <span class="n">attn</span><span class="p">:</span> <span class="n">ops</span><span class="o">.</span><span class="n">divide</span><span class="p">(</span> <span
class="n">attn</span><span class="p">,</span> <span class="n">ops</span><span class="o">.</span><span class="n">convert_to_tensor</span><span class="p">(</span><span class="mf">1e-9</span><span class="p">)</span> <span class="o">+</span> <span class="n">ops</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">attn</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span> <span class="p">)</span> <span class="p">)(</span><span class="n">attn</span><span class="p">)</span> <span class="n">attn</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">attention_dropout</span><span class="p">)(</span><span class="n">attn</span><span class="p">)</span> <span class="c1"># a linear layer M_v</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">dim</span> <span class="o">*</span> <span class="n">dim_coefficient</span> <span class="o">//</span> <span class="n">num_heads</span><span class="p">)(</span><span class="n">attn</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span 
class="n">reshape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_patch</span><span class="p">,</span> <span class="n">dim</span> <span class="o">*</span> <span class="n">dim_coefficient</span><span class="p">])</span> <span class="c1"># a linear layer to project original dim</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">dim</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">projection_dropout</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> </code></pre></div> <hr /> <h2 id="implement-the-mlp-block">Implement the MLP block</h2> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">mlp</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">,</span> <span class="n">mlp_dim</span><span class="p">,</span> <span class="n">drop_rate</span><span class="o">=</span><span class="mf">0.2</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">mlp_dim</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">ops</span><span class="o">.</span><span class="n">gelu</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span 
class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">drop_rate</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">embedding_dim</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">drop_rate</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> </code></pre></div> <hr /> <h2 id="implement-the-transformer-block">Implement the Transformer block</h2> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">transformer_encoder</span><span class="p">(</span> <span class="n">x</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">,</span> <span class="n">mlp_dim</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">dim_coefficient</span><span class="p">,</span> <span class="n">attention_dropout</span><span class="p">,</span> <span class="n">projection_dropout</span><span class="p">,</span> <span class="n">attention_type</span><span class="o">=</span><span class="s2">"external_attention"</span><span class="p">,</span> <span class="p">):</span> <span class="n">residual_1</span> <span class="o">=</span> <span class="n">x</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span><span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">)(</span><span class="n">x</span><span 
class="p">)</span> <span class="k">if</span> <span class="n">attention_type</span> <span class="o">==</span> <span class="s2">"external_attention"</span><span class="p">:</span> <span class="n">x</span> <span class="o">=</span> <span class="n">external_attention</span><span class="p">(</span> <span class="n">x</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">dim_coefficient</span><span class="p">,</span> <span class="n">attention_dropout</span><span class="p">,</span> <span class="n">projection_dropout</span><span class="p">,</span> <span class="p">)</span> <span class="k">elif</span> <span class="n">attention_type</span> <span class="o">==</span> <span class="s2">"self_attention"</span><span class="p">:</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MultiHeadAttention</span><span class="p">(</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">key_dim</span><span class="o">=</span><span class="n">embedding_dim</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="n">attention_dropout</span><span class="p">,</span> <span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">add</span><span class="p">([</span><span class="n">x</span><span class="p">,</span> <span class="n">residual_1</span><span class="p">])</span> <span class="n">residual_2</span> <span class="o">=</span> <span class="n">x</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span><span 
class="n">epsilon</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">mlp</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">,</span> <span class="n">mlp_dim</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">add</span><span class="p">([</span><span class="n">x</span><span class="p">,</span> <span class="n">residual_2</span><span class="p">])</span> <span class="k">return</span> <span class="n">x</span> </code></pre></div> <hr /> <h2 id="implement-the-eanet-model">Implement the EANet model</h2> <p>The EANet model replaces self-attention with external attention. The computational complexity of traditional self-attention is <code>O(d * N ** 2)</code>, where <code>d</code> is the embedding size and <code>N</code> is the number of patches. The authors find that most pixels are closely related to just a few other pixels, so an <code>N</code>-to-<code>N</code> attention matrix may be redundant. As an alternative, they propose an external attention module in which each patch attends to a small learnable memory of <code>S</code> units (in <code>external_attention()</code> above, the <code>M_k</code> layer has <code>S = dim // dim_coefficient</code> units), which reduces the computational complexity to <code>O(d * S * N)</code>. Since <code>d</code> and <code>S</code> are hyperparameters, the cost is linear in <code>N</code>. 
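</p>
<p>To make the complexity argument concrete, here is a minimal NumPy sketch of single-head external attention on one image, including the double normalization used in <code>external_attention()</code> (a softmax over the patch axis, then l1 normalization over the memory axis). It is an illustration under simplifying assumptions, not the Keras implementation: there is no batch axis, no multi-head split, and none of the surrounding input and output projections. The helper name <code>external_attention_np</code> and the sizes <code>N = 256</code>, <code>d = 64</code>, <code>S = 16</code> are hypothetical, and the random memories only stand in for learned parameters.</p>

```python
import numpy as np


def external_attention_np(f, m_k, m_v, eps=1e-9):
    """Single-head external attention for one image (hypothetical helper).

    f:   (N, d) patch features
    m_k: (d, S) external key memory (learned in a real model)
    m_v: (S, d) external value memory (learned in a real model)
    Returns the (N, d) output and the (N, S) attention map.
    """
    # (N, S) logits: each patch is scored against S memory units,
    # which costs O(d * S * N) instead of O(d * N**2).
    logits = f @ m_k
    # First normalization: softmax over the patch axis (axis 0),
    # subtracting the per-column max for numerical stability.
    attn = np.exp(logits - logits.max(axis=0, keepdims=True))
    attn = attn / attn.sum(axis=0, keepdims=True)
    # Second normalization: l1-normalize each patch's row over the S units.
    attn = attn / (eps + attn.sum(axis=1, keepdims=True))
    # Read the output back from the value memory, giving shape (N, d).
    return attn @ m_v, attn


rng = np.random.default_rng(0)
N, d, S = 256, 64, 16  # hypothetical sizes
f = rng.standard_normal((N, d))
# Random stand-ins for learned memories; the 0.1 scale keeps logits moderate.
m_k = 0.1 * rng.standard_normal((d, S))
m_v = 0.1 * rng.standard_normal((S, d))

out, attn = external_attention_np(f, m_k, m_v)
print(out.shape, attn.shape)  # (256, 64) (256, 16)
print("self-attention map entries:", N * N)  # 65536
print("external attention map entries:", N * S)  # 4096
```

<p>Because <code>S</code> is fixed, doubling the number of patches only doubles the size of the attention map, whereas the <code>(N, N)</code> map of self-attention would quadruple. The memories <code>m_k</code> and <code>m_v</code> do not depend on the input image, which is what makes the attention "external".</p>
<p>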
Intuitively, external attention acts somewhat like a drop-patch operation, because much of the information contained in an image's patches is redundant and unimportant.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">get_model</span><span class="p">(</span><span class="n">attention_type</span><span class="o">=</span><span class="s2">"external_attention"</span><span class="p">):</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">input_shape</span><span class="p">)</span> <span class="c1"># Augment images.</span> <span class="n">x</span> <span class="o">=</span> <span class="n">data_augmentation</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="c1"># Extract patches.</span> <span class="n">x</span> <span class="o">=</span> <span class="n">PatchExtract</span><span class="p">(</span><span class="n">patch_size</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Create patch embedding.</span> <span class="n">x</span> <span class="o">=</span> <span class="n">PatchEmbedding</span><span class="p">(</span><span class="n">num_patches</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Stack Transformer blocks.</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_transformer_blocks</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">transformer_encoder</span><span class="p">(</span> <span class="n">x</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">,</span> <span class="n">mlp_dim</span><span
class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">dim_coefficient</span><span class="p">,</span> <span class="n">attention_dropout</span><span class="p">,</span> <span class="n">projection_dropout</span><span class="p">,</span> <span class="n">attention_type</span><span class="p">,</span> <span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">GlobalAveragePooling1D</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"softmax"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">outputs</span><span class="p">)</span> <span class="k">return</span> <span class="n">model</span> </code></pre></div> <hr /> <h2 id="train-on-cifar100">Train on CIFAR-100</h2> <div class="codehilite"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="n">get_model</span><span class="p">(</span><span class="n">attention_type</span><span class="o">=</span><span class="s2">"external_attention"</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span 
class="n">losses</span><span class="o">.</span><span class="n">CategoricalCrossentropy</span><span class="p">(</span><span class="n">label_smoothing</span><span class="o">=</span><span class="n">label_smoothing</span><span class="p">),</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">AdamW</span><span class="p">(</span> <span class="n">learning_rate</span><span class="o">=</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">weight_decay</span> <span class="p">),</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">CategoricalAccuracy</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"accuracy"</span><span class="p">),</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">TopKCategoricalAccuracy</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"top-5-accuracy"</span><span class="p">),</span> <span class="p">],</span> <span class="p">)</span> <span class="n">history</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">num_epochs</span><span class="p">,</span> <span class="n">validation_split</span><span 
class="o">=</span><span class="n">validation_split</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 56s 101ms/step - accuracy: 0.0367 - loss: 4.5081 - top-5-accuracy: 0.1369 - val_accuracy: 0.0659 - val_loss: 4.5736 - val_top-5-accuracy: 0.2277 Epoch 2/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 31s 97ms/step - accuracy: 0.0970 - loss: 4.0453 - top-5-accuracy: 0.2965 - val_accuracy: 0.0624 - val_loss: 5.2273 - val_top-5-accuracy: 0.2178 Epoch 3/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 96ms/step - accuracy: 0.1287 - loss: 3.8706 - top-5-accuracy: 0.3621 - val_accuracy: 0.0690 - val_loss: 5.9141 - val_top-5-accuracy: 0.2342 Epoch 4/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 97ms/step - accuracy: 0.1569 - loss: 3.7600 - top-5-accuracy: 0.4071 - val_accuracy: 0.0806 - val_loss: 5.7599 - val_top-5-accuracy: 0.2510 Epoch 5/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 96ms/step - accuracy: 0.1839 - loss: 3.6534 - top-5-accuracy: 0.4437 - val_accuracy: 0.0954 - val_loss: 5.6725 - val_top-5-accuracy: 0.2772 Epoch 6/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 97ms/step - accuracy: 0.1983 - loss: 3.5784 - top-5-accuracy: 0.4643 - val_accuracy: 0.1050 - val_loss: 5.5299 - val_top-5-accuracy: 0.2898 Epoch 7/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 96ms/step - accuracy: 0.2142 - loss: 3.5126 - top-5-accuracy: 0.4879 - val_accuracy: 0.1108 - val_loss: 5.5076 - val_top-5-accuracy: 0.2995 Epoch 8/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 31s 98ms/step - accuracy: 0.2277 - loss: 3.4624 - top-5-accuracy: 0.5044 - val_accuracy: 0.1157 - val_loss: 5.3608 - val_top-5-accuracy: 0.3065 Epoch 9/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 96ms/step - accuracy: 0.2360 - loss: 3.4188 - top-5-accuracy: 0.5191 - val_accuracy: 0.1200 - val_loss: 5.4690 - val_top-5-accuracy: 0.3106 Epoch 10/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 97ms/step - accuracy: 0.2444 - loss: 3.3684 - top-5-accuracy: 0.5387 - val_accuracy: 
0.1286 - val_loss: 5.1677 - val_top-5-accuracy: 0.3263 Epoch 11/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 96ms/step - accuracy: 0.2532 - loss: 3.3380 - top-5-accuracy: 0.5425 - val_accuracy: 0.1161 - val_loss: 5.5990 - val_top-5-accuracy: 0.3166 Epoch 12/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 96ms/step - accuracy: 0.2646 - loss: 3.2978 - top-5-accuracy: 0.5537 - val_accuracy: 0.1244 - val_loss: 5.5238 - val_top-5-accuracy: 0.3181 Epoch 13/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 96ms/step - accuracy: 0.2722 - loss: 3.2706 - top-5-accuracy: 0.5663 - val_accuracy: 0.1304 - val_loss: 5.2244 - val_top-5-accuracy: 0.3392 Epoch 14/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 97ms/step - accuracy: 0.2773 - loss: 3.2406 - top-5-accuracy: 0.5707 - val_accuracy: 0.1358 - val_loss: 5.2482 - val_top-5-accuracy: 0.3431 Epoch 15/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 96ms/step - accuracy: 0.2839 - loss: 3.2050 - top-5-accuracy: 0.5855 - val_accuracy: 0.1288 - val_loss: 5.3406 - val_top-5-accuracy: 0.3388 Epoch 16/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 97ms/step - accuracy: 0.2881 - loss: 3.1856 - top-5-accuracy: 0.5918 - val_accuracy: 0.1402 - val_loss: 5.2058 - val_top-5-accuracy: 0.3502 Epoch 17/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 96ms/step - accuracy: 0.3006 - loss: 3.1596 - top-5-accuracy: 0.5992 - val_accuracy: 0.1410 - val_loss: 5.2260 - val_top-5-accuracy: 0.3476 Epoch 18/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 97ms/step - accuracy: 0.3047 - loss: 3.1334 - top-5-accuracy: 0.6068 - val_accuracy: 0.1348 - val_loss: 5.2521 - val_top-5-accuracy: 0.3415 Epoch 19/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 96ms/step - accuracy: 0.3058 - loss: 3.1203 - top-5-accuracy: 0.6125 - val_accuracy: 0.1433 - val_loss: 5.1966 - val_top-5-accuracy: 0.3570 Epoch 20/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 96ms/step - accuracy: 0.3105 - loss: 3.0968 - top-5-accuracy: 0.6141 - val_accuracy: 0.1404 - val_loss: 5.3623 - val_top-5-accuracy: 0.3497 Epoch 21/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 96ms/step - accuracy: 0.3161 - loss: 3.0748 - 
top-5-accuracy: 0.6247 - val_accuracy: 0.1486 - val_loss: 5.0754 - val_top-5-accuracy: 0.3740 Epoch 22/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 31s 98ms/step - accuracy: 0.3233 - loss: 3.0536 - top-5-accuracy: 0.6288 - val_accuracy: 0.1472 - val_loss: 5.3110 - val_top-5-accuracy: 0.3545 Epoch 23/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 31s 98ms/step - accuracy: 0.3281 - loss: 3.0272 - top-5-accuracy: 0.6387 - val_accuracy: 0.1408 - val_loss: 5.4392 - val_top-5-accuracy: 0.3524 Epoch 24/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 31s 98ms/step - accuracy: 0.3363 - loss: 3.0089 - top-5-accuracy: 0.6389 - val_accuracy: 0.1395 - val_loss: 5.3579 - val_top-5-accuracy: 0.3555 Epoch 25/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 97ms/step - accuracy: 0.3386 - loss: 2.9958 - top-5-accuracy: 0.6427 - val_accuracy: 0.1550 - val_loss: 5.1783 - val_top-5-accuracy: 0.3655 Epoch 26/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 31s 98ms/step - accuracy: 0.3474 - loss: 2.9824 - top-5-accuracy: 0.6496 - val_accuracy: 0.1448 - val_loss: 5.3971 - val_top-5-accuracy: 0.3596 Epoch 27/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 31s 98ms/step - accuracy: 0.3500 - loss: 2.9647 - top-5-accuracy: 0.6532 - val_accuracy: 0.1519 - val_loss: 5.1895 - val_top-5-accuracy: 0.3665 Epoch 28/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 31s 98ms/step - accuracy: 0.3561 - loss: 2.9414 - top-5-accuracy: 0.6604 - val_accuracy: 0.1470 - val_loss: 5.4482 - val_top-5-accuracy: 0.3600 Epoch 29/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 97ms/step - accuracy: 0.3572 - loss: 2.9410 - top-5-accuracy: 0.6593 - val_accuracy: 0.1572 - val_loss: 5.1866 - val_top-5-accuracy: 0.3795 Epoch 30/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 31s 100ms/step - accuracy: 0.3561 - loss: 2.9263 - top-5-accuracy: 0.6670 - val_accuracy: 0.1638 - val_loss: 5.0637 - val_top-5-accuracy: 0.3934 Epoch 31/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 97ms/step - accuracy: 0.3621 - loss: 2.9050 - top-5-accuracy: 0.6730 - val_accuracy: 0.1589 - val_loss: 5.2504 - val_top-5-accuracy: 0.3835 Epoch 32/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 96ms/step 
- accuracy: 0.3675 - loss: 2.8898 - top-5-accuracy: 0.6754 - val_accuracy: 0.1690 - val_loss: 5.0613 - val_top-5-accuracy: 0.3950 Epoch 33/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 97ms/step - accuracy: 0.3771 - loss: 2.8710 - top-5-accuracy: 0.6784 - val_accuracy: 0.1596 - val_loss: 5.1941 - val_top-5-accuracy: 0.3784 Epoch 34/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 97ms/step - accuracy: 0.3797 - loss: 2.8536 - top-5-accuracy: 0.6880 - val_accuracy: 0.1686 - val_loss: 5.1522 - val_top-5-accuracy: 0.3879 Epoch 35/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 97ms/step - accuracy: 0.3792 - loss: 2.8504 - top-5-accuracy: 0.6871 - val_accuracy: 0.1525 - val_loss: 5.2875 - val_top-5-accuracy: 0.3735 Epoch 36/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 96ms/step - accuracy: 0.3868 - loss: 2.8278 - top-5-accuracy: 0.6950 - val_accuracy: 0.1573 - val_loss: 5.2148 - val_top-5-accuracy: 0.3797 Epoch 37/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 97ms/step - accuracy: 0.3869 - loss: 2.8129 - top-5-accuracy: 0.6973 - val_accuracy: 0.1562 - val_loss: 5.4344 - val_top-5-accuracy: 0.3646 Epoch 38/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 96ms/step - accuracy: 0.3866 - loss: 2.8129 - top-5-accuracy: 0.6977 - val_accuracy: 0.1610 - val_loss: 5.2807 - val_top-5-accuracy: 0.3772 Epoch 39/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 96ms/step - accuracy: 0.3934 - loss: 2.7990 - top-5-accuracy: 0.7006 - val_accuracy: 0.1681 - val_loss: 5.0741 - val_top-5-accuracy: 0.3967 Epoch 40/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 96ms/step - accuracy: 0.3947 - loss: 2.7863 - top-5-accuracy: 0.7065 - val_accuracy: 0.1612 - val_loss: 5.1039 - val_top-5-accuracy: 0.3885 Epoch 41/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 97ms/step - accuracy: 0.4030 - loss: 2.7687 - top-5-accuracy: 0.7092 - val_accuracy: 0.1592 - val_loss: 5.1138 - val_top-5-accuracy: 0.3837 Epoch 42/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 96ms/step - accuracy: 0.4013 - loss: 2.7706 - top-5-accuracy: 0.7071 - val_accuracy: 0.1718 - val_loss: 5.1391 - val_top-5-accuracy: 0.3938 Epoch 43/50 313/313 
━━━━━━━━━━━━━━━━━━━━ 30s 97ms/step - accuracy: 0.4062 - loss: 2.7569 - top-5-accuracy: 0.7137 - val_accuracy: 0.1593 - val_loss: 5.3004 - val_top-5-accuracy: 0.3781 Epoch 44/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 97ms/step - accuracy: 0.4109 - loss: 2.7429 - top-5-accuracy: 0.7129 - val_accuracy: 0.1823 - val_loss: 5.0221 - val_top-5-accuracy: 0.4038 Epoch 45/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 96ms/step - accuracy: 0.4074 - loss: 2.7312 - top-5-accuracy: 0.7212 - val_accuracy: 0.1706 - val_loss: 5.1799 - val_top-5-accuracy: 0.3898 Epoch 46/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 95ms/step - accuracy: 0.4175 - loss: 2.7121 - top-5-accuracy: 0.7202 - val_accuracy: 0.1701 - val_loss: 5.1674 - val_top-5-accuracy: 0.3910 Epoch 47/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 31s 101ms/step - accuracy: 0.4187 - loss: 2.7178 - top-5-accuracy: 0.7227 - val_accuracy: 0.1764 - val_loss: 5.0161 - val_top-5-accuracy: 0.4027 Epoch 48/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 96ms/step - accuracy: 0.4180 - loss: 2.7045 - top-5-accuracy: 0.7246 - val_accuracy: 0.1709 - val_loss: 5.0650 - val_top-5-accuracy: 0.3907 Epoch 49/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 96ms/step - accuracy: 0.4264 - loss: 2.6857 - top-5-accuracy: 0.7276 - val_accuracy: 0.1591 - val_loss: 5.3416 - val_top-5-accuracy: 0.3732 Epoch 50/50 313/313 ━━━━━━━━━━━━━━━━━━━━ 30s 96ms/step - accuracy: 0.4245 - loss: 2.6878 - top-5-accuracy: 0.7271 - val_accuracy: 0.1778 - val_loss: 5.1093 - val_top-5-accuracy: 0.3987 </code></pre></div> </div> <h3 id="lets-visualize-the-training-progress-of-the-model">Let's visualize the training progress of the model.</h3> <div class="codehilite"><pre><span></span><code><span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">"loss"</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span 
class="s2">"train_loss"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">"val_loss"</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s2">"val_loss"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s2">"Epochs"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s2">"Loss"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">"Train and Validation Losses Over Epochs"</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">grid</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/eanet/eanet_24_0.png" /></p> <h3 id="lets-display-the-final-results-of-the-test-on-cifar100">Let's display the final results of the test on CIFAR-100.</h3> <div class="codehilite"><pre><span></span><code><span class="n">loss</span><span class="p">,</span> <span class="n">accuracy</span><span class="p">,</span> <span class="n">top_5_accuracy</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span 
class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Test loss: </span><span class="si">{</span><span class="nb">round</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span><span class="w"> </span><span class="mi">2</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Test accuracy: </span><span class="si">{</span><span class="nb">round</span><span class="p">(</span><span class="n">accuracy</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="mi">100</span><span class="p">,</span><span class="w"> </span><span class="mi">2</span><span class="p">)</span><span class="si">}</span><span class="s2">%"</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Test top 5 accuracy: </span><span class="si">{</span><span class="nb">round</span><span class="p">(</span><span class="n">top_5_accuracy</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="mi">100</span><span class="p">,</span><span class="w"> </span><span class="mi">2</span><span class="p">)</span><span class="si">}</span><span class="s2">%"</span><span class="p">)</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 313/313 ━━━━━━━━━━━━━━━━━━━━ 3s 9ms/step - accuracy: 0.1774 - loss: 5.0871 - top-5-accuracy: 0.3963
Test loss: 5.15
Test accuracy: 17.26%
Test top 5 accuracy: 38.94%
</code></pre></div> </div> <p>EANet simply replaces the self-attention in ViT with external attention. A standard ViT reaches ~73% top-5 and ~41% top-1 test accuracy after 50 epochs of training, while using 0.6M parameters. 
Under the same experimental setup and hyperparameters as the ViT baseline, the EANet model we just trained has only 0.3M parameters and reaches ~73% top-5 and ~43% top-1 test accuracy, which suggests that external attention can match self-attention at about half the parameter count. Note, however, that the run logged above ends at 17.26% top-1 and 38.94% top-5 test accuracy, well below these figures, so your results may differ considerably.</p> <p>We only show the training process of EANet here; you can train ViT under the same experimental conditions and compare the test results.</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#image-classification-with-eanet-external-attention-transformer'>Image classification with EANet (External Attention Transformer)</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#prepare-the-data'>Prepare the data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#configure-the-hyperparameters'>Configure the hyperparameters</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#use-data-augmentation'>Use data augmentation</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#implement-the-patch-extraction-and-encoding-layer'>Implement the patch extraction and encoding layer</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#implement-the-external-attention-block'>Implement the external attention block</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#implement-the-mlp-block'>Implement the MLP block</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#implement-the-transformer-block'>Implement the Transformer block</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#implement-the-eanet-model'>Implement the EANet model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-on-cifar100'>Train on CIFAR-100</a> </div> <div class='k-outline-depth-3'> <a href='#lets-visualize-the-training-progress-of-the-model'>Let's visualize the training progress of the model.</a> </div> <div class='k-outline-depth-3'> <a
href='#lets-display-the-final-results-of-the-test-on-cifar100'>Let's display the final results of the test on CIFAR-100.</a> </div> </div> </div> </div> </div> </body> </html>