CINXE.COM
When Recurrence meets Transformers
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/temporal_latent_bottleneck/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: When Recurrence meets Transformers"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: When Recurrence meets Transformers"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>When Recurrence meets Transformers</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a class="nav-sublink2" href="/examples/vision/mobilevit/">A mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using Composable Fully-Convolutional Networks</a> <a class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved Robustness</a> <a class="nav-sublink2" href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/bit/">Image Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a class="nav-sublink2" href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a class="nav-sublink2 active" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/vision/'>Computer Vision</a> / When Recurrence meets Transformers </div> <div class='k-content'> <h1 id="when-recurrence-meets-transformers">When Recurrence meets Transformers</h1> <p><strong>Author:</strong> <a href="https://twitter.com/ariG23498">Aritra Roy Gosthipaty</a>, <a href="https://twitter.com/halcyonrayes">Suvaditya Mukherjee</a><br> <strong>Date created:</strong> 2023/03/12<br> <strong>Last modified:</strong> 2024/11/12<br> <strong>Description:</strong> Image Classification with Temporal Latent Bottleneck Networks.</p> <div class='example_version_banner keras_2'>ⓘ This example uses Keras 2</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/temporal_latent_bottleneck.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/temporal_latent_bottleneck.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>A simple Recurrent Neural Network (RNN) displays a strong inductive bias towards learning <strong>temporally compressed representations</strong>. <strong>Equation 1</strong> shows the recurrence formula, where <code>h_t</code> is the compressed representation (a single vector) of the entire input sequence <code>x</code>.</p> <table> <thead> <tr> <th style="text-align: center;"><img alt="Equation of RNN" src="https://i.imgur.com/Kdyj2jr.png" /></th> </tr> </thead> <tbody> <tr> <td style="text-align: center;"><strong>Equation 1</strong>: The recurrence equation. (Source: Aritra and Suvaditya)</td> </tr> </tbody> </table> <p>On the other hand, Transformers (<a href="https://arxiv.org/abs/1706.03762">Vaswani et. al</a>) have little inductive bias towards learning temporally compressed representations. Transformer has achieved SoTA results in Natural Language Processing (NLP) and Vision tasks with its pairwise attention mechanism.</p> <p>While the Transformer has the ability to <strong>attend</strong> to different sections of the input sequence, the computation of attention is quadratic in nature.</p> <p><a href="https://arxiv.org/abs/2205.14794">Didolkar et. al</a> argue that having a more compressed representation of a sequence may be beneficial for <em>generalization</em>, as it can be easily <strong>re-used</strong> and <strong>re-purposed</strong> with fewer irrelevant details. While compression is good, they also notice that too much of it can harm expressiveness.</p> <p>The authors propose a solution that divides computation into <strong>two streams</strong>. A <em>slow stream</em> that is recurrent in nature and a <em>fast stream</em> that is parameterized as a Transformer. While this method has the novelty of introducing different processing streams in order to preserve and process latent states, it has parallels drawn in other works like the <a href="https://arxiv.org/abs/2103.03206">Perceiver Mechanism (by Jaegle et. al.)</a> and <a href="https://arxiv.org/abs/2009.01719">Grounded Language Learning Fast and Slow (by Hill et. al.)</a>.</p> <p>The following example explores how we can make use of the new Temporal Latent Bottleneck mechanism to perform image classification on the CIFAR-10 dataset. We implement this model by making a custom <code>RNNCell</code> implementation in order to make a <strong>performant</strong> and <strong>vectorized</strong> design.</p> <hr /> <h2 id="setup-imports">Setup imports</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span><span class="p">,</span> <span class="n">ops</span><span class="p">,</span> <span class="n">mixed_precision</span> <span class="kn">from</span> <span class="nn">keras.optimizers</span> <span class="kn">import</span> <span class="n">AdamW</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">random</span> <span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">pyplot</span> <span class="k">as</span> <span class="n">plt</span> <span class="c1"># Set seed for reproducibility.</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">set_random_seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="setting-required-configuration">Setting required configuration</h2> <p>We set a few configuration parameters that are needed within the pipeline we have designed. The current parameters are for use with the <a href="https://www.cs.toronto.edu/~kriz/cifar.html">CIFAR10 dataset</a>.</p> <p>The model also supports <code>mixed-precision</code> settings, which would quantize the model to use <code>16-bit</code> float numbers where it can, while keeping some parameters in <code>32-bit</code> as needed for numerical stability. This brings performance benefits as the footprint of the model decreases significantly while bringing speed boosts at inference-time.</p> <div class="codehilite"><pre><span></span><code><span class="n">config</span> <span class="o">=</span> <span class="p">{</span> <span class="s2">"mixed_precision"</span><span class="p">:</span> <span class="kc">True</span><span class="p">,</span> <span class="s2">"dataset"</span><span class="p">:</span> <span class="s2">"cifar10"</span><span class="p">,</span> <span class="s2">"train_slice"</span><span class="p">:</span> <span class="mi">40_000</span><span class="p">,</span> <span class="s2">"batch_size"</span><span class="p">:</span> <span class="mi">2048</span><span class="p">,</span> <span class="s2">"buffer_size"</span><span class="p">:</span> <span class="mi">2048</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="s2">"input_shape"</span><span class="p">:</span> <span class="p">[</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">],</span> <span class="s2">"image_size"</span><span class="p">:</span> <span class="mi">48</span><span class="p">,</span> <span class="s2">"num_classes"</span><span class="p">:</span> <span class="mi">10</span><span class="p">,</span> <span class="s2">"learning_rate"</span><span class="p">:</span> <span class="mf">1e-4</span><span class="p">,</span> <span class="s2">"weight_decay"</span><span class="p">:</span> <span class="mf">1e-4</span><span class="p">,</span> <span class="s2">"epochs"</span><span class="p">:</span> <span class="mi">30</span><span class="p">,</span> <span class="s2">"patch_size"</span><span class="p">:</span> <span class="mi">4</span><span class="p">,</span> <span class="s2">"embed_dim"</span><span class="p">:</span> <span class="mi">64</span><span class="p">,</span> <span class="s2">"chunk_size"</span><span class="p">:</span> <span class="mi">8</span><span class="p">,</span> <span class="s2">"r"</span><span class="p">:</span> <span class="mi">2</span><span class="p">,</span> <span class="s2">"num_layers"</span><span class="p">:</span> <span class="mi">4</span><span class="p">,</span> <span class="s2">"ffn_drop"</span><span class="p">:</span> <span class="mf">0.2</span><span class="p">,</span> <span class="s2">"attn_drop"</span><span class="p">:</span> <span class="mf">0.2</span><span class="p">,</span> <span class="s2">"num_heads"</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span> <span class="p">}</span> <span class="k">if</span> <span class="n">config</span><span class="p">[</span><span class="s2">"mixed_precision"</span><span class="p">]:</span> <span class="n">policy</span> <span class="o">=</span> <span class="n">mixed_precision</span><span class="o">.</span><span class="n">Policy</span><span class="p">(</span><span class="s2">"mixed_float16"</span><span class="p">)</span> <span class="n">mixed_precision</span><span class="o">.</span><span class="n">set_global_policy</span><span class="p">(</span><span class="n">policy</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="loading-the-cifar10-dataset">Loading the CIFAR-10 dataset</h2> <p>We are going to use the CIFAR10 dataset for running our experiments. This dataset contains a training set of <code>50,000</code> images for <code>10</code> classes with the standard image size of <code>(32, 32, 3)</code>.</p> <p>It also has a separate set of <code>10,000</code> images with similar characteristics. More information about the dataset may be found at the official site for the dataset as well as <a href="https://keras.io/api/datasets/cifar10/"><code>keras.datasets.cifar10</code></a> API reference</p> <div class="codehilite"><pre><span></span><code><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">cifar10</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span> <span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_val</span><span class="p">,</span> <span class="n">y_val</span><span class="p">)</span> <span class="o">=</span> <span class="p">(</span> <span class="p">(</span><span class="n">x_train</span><span class="p">[:</span> <span class="n">config</span><span class="p">[</span><span class="s2">"train_slice"</span><span class="p">]],</span> <span class="n">y_train</span><span class="p">[:</span> <span class="n">config</span><span class="p">[</span><span class="s2">"train_slice"</span><span class="p">]]),</span> <span class="p">(</span><span class="n">x_train</span><span class="p">[</span><span class="n">config</span><span class="p">[</span><span class="s2">"train_slice"</span><span class="p">]</span> <span class="p">:],</span> <span class="n">y_train</span><span class="p">[</span><span class="n">config</span><span class="p">[</span><span class="s2">"train_slice"</span><span class="p">]</span> <span class="p">:]),</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="define-data-augmentation-for-the-training-and-validationtest-pipelines">Define data augmentation for the training and validation/test pipelines</h2> <p>We define separate pipelines for performing image augmentation on our data. This step is important to make the model more robust to changes, helping it generalize better. The preprocessing and augmentation steps we perform are as follows:</p> <ul> <li><code>Rescaling</code> (training, test): This step is performed to normalize all image pixel values from the <code>[0,255]</code> range to <code>[0,1)</code>. This helps in maintaining numerical stability later ahead during training.</li> </ul> <ul> <li><code>Resizing</code> (training, test): We resize the image from it's original size of (32, 32) to (52, 52). This is done to account for the Random Crop, as well as comply with the specifications of the data given in the paper.</li> </ul> <ul> <li><code>RandomCrop</code> (training): This layer randomly selects a crop/sub-region of the image with size <code>(48, 48)</code>.</li> </ul> <ul> <li><code>RandomFlip</code> (training): This layer randomly flips all the images horizontally, keeping image sizes the same.</li> </ul> <div class="codehilite"><pre><span></span><code><span class="c1"># Build the `train` augmentation pipeline.</span> <span class="n">train_augmentation</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">layers</span><span class="o">.</span><span class="n">Rescaling</span><span class="p">(</span><span class="mi">1</span> <span class="o">/</span> <span class="mf">255.0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Resizing</span><span class="p">(</span> <span class="n">config</span><span class="p">[</span><span class="s2">"input_shape"</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mi">20</span><span class="p">,</span> <span class="n">config</span><span class="p">[</span><span class="s2">"input_shape"</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mi">20</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span><span class="p">,</span> <span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">RandomCrop</span><span class="p">(</span><span class="n">config</span><span class="p">[</span><span class="s2">"image_size"</span><span class="p">],</span> <span class="n">config</span><span class="p">[</span><span class="s2">"image_size"</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">RandomFlip</span><span class="p">(</span><span class="s2">"horizontal"</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span><span class="p">),</span> <span class="p">],</span> <span class="n">name</span><span class="o">=</span><span class="s2">"train_data_augmentation"</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Build the `val` and `test` data pipeline.</span> <span class="n">test_augmentation</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">layers</span><span class="o">.</span><span class="n">Rescaling</span><span class="p">(</span><span class="mi">1</span> <span class="o">/</span> <span class="mf">255.0</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Resizing</span><span class="p">(</span><span class="n">config</span><span class="p">[</span><span class="s2">"image_size"</span><span class="p">],</span> <span class="n">config</span><span class="p">[</span><span class="s2">"image_size"</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span><span class="p">),</span> <span class="p">],</span> <span class="n">name</span><span class="o">=</span><span class="s2">"test_data_augmentation"</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># We define functions in place of simple lambda functions to run through the</span> <span class="c1"># [`keras.Sequential`](/api/models/sequential#sequential-class)in order to solve this warning:</span> <span class="c1"># (https://github.com/tensorflow/tensorflow/issues/56089)</span> <span class="k">def</span> <span class="nf">train_map_fn</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">label</span><span class="p">):</span> <span class="k">return</span> <span class="n">train_augmentation</span><span class="p">(</span><span class="n">image</span><span class="p">),</span> <span class="n">label</span> <span class="k">def</span> <span class="nf">test_map_fn</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">label</span><span class="p">):</span> <span class="k">return</span> <span class="n">test_augmentation</span><span class="p">(</span><span class="n">image</span><span class="p">),</span> <span class="n">label</span> </code></pre></div> <hr /> <h2 id="load-dataset-into-pydataset-object">Load dataset into <code>PyDataset</code> object</h2> <ul> <li>We take the <code>np.ndarray</code> instance of the datasets and wrap a class around it, wrapping a <a href="/api/utils/python_utils#pydataset-class"><code>keras.utils.PyDataset</code></a> and apply augmentations with keras preprocessing layers.</li> </ul> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">Dataset</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">PyDataset</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">x_data</span><span class="p">,</span> <span class="n">y_data</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">preprocess_fn</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span> <span class="p">):</span> <span class="k">if</span> <span class="n">shuffle</span><span class="p">:</span> <span class="n">perm</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">permutation</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">x_data</span><span class="p">))</span> <span class="n">x_data</span> <span class="o">=</span> <span class="n">x_data</span><span class="p">[</span><span class="n">perm</span><span class="p">]</span> <span class="n">y_data</span> <span class="o">=</span> <span class="n">y_data</span><span class="p">[</span><span class="n">perm</span><span class="p">]</span> <span class="bp">self</span><span class="o">.</span><span class="n">x_data</span> <span class="o">=</span> <span class="n">x_data</span> <span class="bp">self</span><span class="o">.</span><span class="n">y_data</span> <span class="o">=</span> <span class="n">y_data</span> <span class="bp">self</span><span class="o">.</span><span class="n">preprocess_fn</span> <span class="o">=</span> <span class="n">preprocess_fn</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">kwargs</span><span class="p">)</span> <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">x_data</span><span class="p">)</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">idx</span><span class="p">):</span> <span class="n">batch_x</span><span class="p">,</span> <span class="n">batch_y</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">idx</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="p">(</span><span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">):</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">x_data</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">y_data</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">preprocess_fn</span><span class="p">:</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">preprocess_fn</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> <span class="n">batch_x</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">batch_y</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">y</span><span class="p">)</span> <span class="n">batch_x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">batch_x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">batch_y</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">batch_y</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="k">return</span> <span class="n">batch_x</span><span class="p">,</span> <span class="n">batch_y</span> <span class="n">train_ds</span> <span class="o">=</span> <span class="n">Dataset</span><span class="p">(</span> <span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">config</span><span class="p">[</span><span class="s2">"batch_size"</span><span class="p">],</span> <span class="n">preprocess_fn</span><span class="o">=</span><span class="n">train_map_fn</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="n">val_ds</span> <span class="o">=</span> <span class="n">Dataset</span><span class="p">(</span><span class="n">x_val</span><span class="p">,</span> <span class="n">y_val</span><span class="p">,</span> <span class="n">config</span><span class="p">[</span><span class="s2">"batch_size"</span><span class="p">],</span> <span class="n">preprocess_fn</span><span class="o">=</span><span class="n">test_map_fn</span><span class="p">)</span> <span class="n">test_ds</span> <span class="o">=</span> <span class="n">Dataset</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">,</span> <span class="n">config</span><span class="p">[</span><span class="s2">"batch_size"</span><span class="p">],</span> <span class="n">preprocess_fn</span><span class="o">=</span><span class="n">test_map_fn</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="temporal-latent-bottleneck">Temporal Latent Bottleneck</h2> <p>An excerpt from the paper:</p> <blockquote> <p>In the brain, short-term and long-term memory have developed in a specialized way. Short-term memory is allowed to change very quickly to react to immediate sensory inputs and perception. By contrast, long-term memory changes slowly, is highly selective and involves repeated consolidation.</p> </blockquote> <p>Inspired from the short-term and long-term memory the authors introduce the fast stream and slow stream computation. The fast stream has a short-term memory with a high capacity that reacts quickly to sensory input (Transformers). The slow stream has long-term memory which updates at a slower rate and summarizes the most relevant information (Recurrence).</p> <p>To implement this idea we need to:</p> <ul> <li>Take a sequence of data.</li> <li>Divide the sequence into fixed-size chunks.</li> <li>Fast stream operates within each chunk. It provides fine-grained local information.</li> <li>Slow stream consolidates and aggregates information across chunks. It provides coarse-grained distant information.</li> </ul> <p>The fast and slow stream induce what is called <strong>information asymmetry</strong>. The two streams interact with each other through a bottleneck of attention. <strong>Figure 1</strong> shows the architecture of the model.</p> <table> <thead> <tr> <th style="text-align: center;"><img alt="Architecture of the model" src="https://i.imgur.com/bxdLPNH.png" /></th> </tr> </thead> <tbody> <tr> <td style="text-align: center;">Figure 1: Architecture of the model. (Source: https://arxiv.org/abs/2205.14794)</td> </tr> </tbody> </table> <p>A PyTorch-style pseudocode is also proposed by the authors as shown in <strong>Algorithm 1</strong>.</p> <table> <thead> <tr> <th style="text-align: center;"><img alt="Pseudocode of the model" src="https://i.imgur.com/s8a5Vz9.png" /></th> </tr> </thead> <tbody> <tr> <td style="text-align: center;">Algorithm 1: PyTorch style pseudocode. (Source: https://arxiv.org/abs/2205.14794)</td> </tr> </tbody> </table> <h3 id="patchembedding-layer"><code>PatchEmbedding</code> layer</h3> <p>This custom <a href="/api/layers/base_layer#layer-class"><code>keras.layers.Layer</code></a> is useful for generating patches from the image and transform them into a higher-dimensional embedding space using <a href="/api/layers/core_layers/embedding#embedding-class"><code>keras.layers.Embedding</code></a>. The patching operation is done using a <a href="/api/layers/convolution_layers/convolution2d#conv2d-class"><code>keras.layers.Conv2D</code></a> instance.</p> <p>Once the patching of images is complete, we reshape the image patches in order to get a flattened representation where the number of dimensions is the embedding dimension. At this stage, we also inject positional information to the tokens.</p> <p>After we obtain the tokens we chunk them. The chunking operation involves taking fixed-size sequences from the embedding output to create 'chunks', which will then be used as the final input to the model.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">PatchEmbedding</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Image to Patch Embedding.</span> <span class="sd"> Args:</span> <span class="sd"> image_size (`Tuple[int]`): Size of the input image.</span> <span class="sd"> patch_size (`Tuple[int]`): Size of the patch.</span> <span class="sd"> embed_dim (`int`): Dimension of the embedding.</span> <span class="sd"> chunk_size (`int`): Number of patches to be chunked.</span> <span class="sd"> """</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">patch_size</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">chunk_size</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span> <span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="c1"># Compute the patch resolution.</span> <span class="n">patch_resolution</span> <span class="o">=</span> <span class="p">[</span> <span class="n">image_size</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">//</span> <span class="n">patch_size</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">image_size</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">//</span> <span class="n">patch_size</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="p">]</span> <span class="c1"># Store the parameters.</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_size</span> <span class="o">=</span> <span class="n">image_size</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span> <span class="o">=</span> <span class="n">patch_size</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span> <span class="o">=</span> <span class="n">embed_dim</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_resolution</span> <span class="o">=</span> <span class="n">patch_resolution</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_patches</span> <span class="o">=</span> <span class="n">patch_resolution</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">patch_resolution</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="c1"># Define the positions of the patches.</span> <span class="bp">self</span><span class="o">.</span><span class="n">positions</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">start</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">stop</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_patches</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># Create the layers.</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span> <span class="n">filters</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="n">patch_size</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="n">patch_size</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"projection"</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">flatten</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Reshape</span><span class="p">(</span> <span class="n">target_shape</span><span class="o">=</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"flatten"</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">position_embedding</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span> <span class="n">input_dim</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_patches</span><span class="p">,</span> <span class="n">output_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"position_embedding"</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">layernorm</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span> <span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"layernorm"</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunking_layer</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Reshape</span><span class="p">(</span> <span class="n">target_shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_patches</span> <span class="o">//</span> <span class="n">chunk_size</span><span class="p">,</span> <span class="n">chunk_size</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"chunking_layer"</span><span class="p">,</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="c1"># Project the inputs to the embedding dimension.</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="c1"># Flatten the pathces and add position embedding.</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">flatten</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">position_embedding</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">positions</span><span class="p">)</span> <span class="c1"># Normalize the embeddings.</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layernorm</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Chunk the tokens.</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunking_layer</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> </code></pre></div> <h3 id="feedforwardnetwork-layer"><code>FeedForwardNetwork</code> Layer</h3> <p>This custom <a href="/api/layers/base_layer#layer-class"><code>keras.layers.Layer</code></a> instance allows us to define a generic FFN along with a dropout.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">FeedForwardNetwork</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Feed Forward Network.</span> <span class="sd"> Args:</span> <span class="sd"> dims (`int`): Number of units in FFN.</span> <span class="sd"> dropout (`float`): Dropout probability for FFN.</span> <span class="sd"> """</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">,</span> <span class="n">dropout</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="c1"># Create the layers.</span> <span class="bp">self</span><span class="o">.</span><span class="n">ffn</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="mi">4</span> <span class="o">*</span> <span class="n">dims</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"gelu"</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="n">dims</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">rate</span><span class="o">=</span><span class="n">dropout</span><span class="p">),</span> <span class="p">],</span> <span class="n">name</span><span class="o">=</span><span class="s2">"ffn"</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">layernorm</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span> <span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"layernorm"</span><span class="p">,</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="c1"># Apply the FFN.</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layernorm</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">inputs</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">ffn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> </code></pre></div> <h3 id="baseattention-layer"><code>BaseAttention</code> layer</h3> <p>This custom <a href="/api/layers/base_layer#layer-class"><code>keras.layers.Layer</code></a> instance is a <code>super</code>/<code>base</code> class that wraps a <a href="/api/layers/attention_layers/multi_head_attention#multiheadattention-class"><code>keras.layers.MultiHeadAttention</code></a> layer along with some other components. This gives us basic common denominator functionality for all the Attention layers/modules in our model.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">BaseAttention</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Base Attention Module.</span> <span class="sd"> Args:</span> <span class="sd"> num_heads (`int`): Number of attention heads.</span> <span class="sd"> key_dim (`int`): Size of each attention head for key.</span> <span class="sd"> dropout (`float`): Dropout probability for attention module.</span> <span class="sd"> """</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">key_dim</span><span class="p">,</span> <span class="n">dropout</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">multi_head_attention</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MultiHeadAttention</span><span class="p">(</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">key_dim</span><span class="o">=</span><span class="n">key_dim</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="n">dropout</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"mha"</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_layernorm</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span> <span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"q_layernorm"</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_layernorm</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span> <span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"k_layernorm"</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_layernorm</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span> <span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"v_layernorm"</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention_scores</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">):</span> <span class="c1"># Apply the attention module.</span> <span class="n">query</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">query_layernorm</span><span class="p">(</span><span class="n">input_query</span><span class="p">)</span> <span class="n">key</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_layernorm</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> <span class="n">value</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">value_layernorm</span><span class="p">(</span><span class="n">value</span><span class="p">)</span> <span class="p">(</span><span class="n">attention_outputs</span><span class="p">,</span> <span class="n">attention_scores</span><span class="p">)</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">multi_head_attention</span><span class="p">(</span> <span class="n">query</span><span class="o">=</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">value</span><span class="p">,</span> <span class="n">return_attention_scores</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Save the attention scores for later visualization.</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention_scores</span> <span class="o">=</span> <span class="n">attention_scores</span> <span class="c1"># Add the input to the attention output.</span> <span class="n">x</span> <span class="o">=</span> <span class="n">input_query</span> <span class="o">+</span> <span class="n">attention_outputs</span> <span class="k">return</span> <span class="n">x</span> </code></pre></div> <h3 id="attention-with-feedforwardnetwork-layer"><code>Attention</code> with <code>FeedForwardNetwork</code> layer</h3> <p>This custom <a href="/api/layers/base_layer#layer-class"><code>keras.layers.Layer</code></a> implementation combines the <code>BaseAttention</code> and <code>FeedForwardNetwork</code> components to develop one block which will be used repeatedly within the model. This module is highly customizable and flexible, allowing for changes within the internal layers.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">AttentionWithFFN</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Attention with Feed Forward Network.</span> <span class="sd"> Args:</span> <span class="sd"> ffn_dims (`int`): Number of units in FFN.</span> <span class="sd"> ffn_dropout (`float`): Dropout probability for FFN.</span> <span class="sd"> num_heads (`int`): Number of attention heads.</span> <span class="sd"> key_dim (`int`): Size of each attention head for key.</span> <span class="sd"> attn_dropout (`float`): Dropout probability for attention module.</span> <span class="sd"> """</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">ffn_dims</span><span class="p">,</span> <span class="n">ffn_dropout</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">key_dim</span><span class="p">,</span> <span class="n">attn_dropout</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span> <span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="c1"># Create the layers.</span> <span class="bp">self</span><span class="o">.</span><span class="n">fast_stream_attention</span> <span class="o">=</span> <span class="n">BaseAttention</span><span class="p">(</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">key_dim</span><span class="o">=</span><span class="n">key_dim</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="n">attn_dropout</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"base_attn"</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">slow_stream_attention</span> <span class="o">=</span> <span class="n">BaseAttention</span><span class="p">(</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">key_dim</span><span class="o">=</span><span class="n">key_dim</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="n">attn_dropout</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"base_attn"</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">ffn</span> <span class="o">=</span> <span class="n">FeedForwardNetwork</span><span class="p">(</span> <span class="n">dims</span><span class="o">=</span><span class="n">ffn_dims</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="n">ffn_dropout</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"ffn"</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention_scores</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">built</span> <span class="o">=</span> <span class="kc">True</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">,</span> <span class="n">stream</span><span class="o">=</span><span class="s2">"fast"</span><span class="p">):</span> <span class="c1"># Apply the attention module.</span> <span class="n">attention_layer</span> <span class="o">=</span> <span class="p">{</span> <span class="s2">"fast"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">fast_stream_attention</span><span class="p">,</span> <span class="s2">"slow"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">slow_stream_attention</span><span class="p">,</span> <span class="p">}[</span><span class="n">stream</span><span class="p">]</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">query</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span> <span class="n">query</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">key</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span> <span class="n">key</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">value</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span><span class="p">:</span> <span class="n">value</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">value</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">attention_layer</span><span class="p">(</span><span class="n">query</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">value</span><span class="p">)</span> <span class="c1"># Save the attention scores for later visualization.</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention_scores</span> <span class="o">=</span> <span class="n">attention_layer</span><span class="o">.</span><span class="n">attention_scores</span> <span class="c1"># Apply the FFN.</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ffn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> </code></pre></div> <h3 id="temporal-latent-bottleneck">Custom RNN Cell for <strong>Temporal Latent Bottleneck</strong> and <strong>Perceptual Module</strong></h3> <p><strong>Algorithm 1</strong> (the pseudocode) depicts recurrence with the help of for loops. Looping does make the implementation simpler, harming the training time. In this section we wrap the custom recurrence logic inside of the <code>CustomRecurrentCell</code>. This custom cell will then be wrapped with the <a href="https://keras.io/api/layers/recurrent_layers/rnn/">Keras RNN API</a> that makes the entire code vectorizable.</p> <p>This custom cell, implemented as a <a href="/api/layers/base_layer#layer-class"><code>keras.layers.Layer</code></a>, is the integral part of the logic for the model. The cell's functionality can be divided into 2 parts: - <strong>Slow Stream (Temporal Latent Bottleneck):</strong></p> <ul> <li>This module consists of a single <code>AttentionWithFFN</code> layer that parses the output of the previous Slow Stream, an intermediate hidden representation (which is the <em>latent</em> in Temporal Latent Bottleneck) as the Query, and the output of the latest Fast Stream as Key and Value. This layer can also be construed as a <em>CrossAttention</em> layer.</li> </ul> <ul> <li><strong>Fast Stream (Perceptual Module):</strong></li> </ul> <ul> <li>This module consists of intertwined <code>AttentionWithFFN</code> layers. This stream consists of <em>n</em> layers of <code>SelfAttention</code> and <code>CrossAttention</code> in a sequential manner.</li> <li>Here, some layers take the chunked input as the Query, Key and Value (Also referred to as the <em>SelfAttention</em> layer).</li> <li>The other layers take the intermediate state outputs from within the Temporal Latent Bottleneck module as the Query while using the output of the previous Self-Attention layers before it as the Key and Value.</li> </ul> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">CustomRecurrentCell</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Custom Recurrent Cell.</span> <span class="sd"> Args:</span> <span class="sd"> chunk_size (`int`): Number of tokens in a chunk.</span> <span class="sd"> r (`int`): One Cross Attention per **r** Self Attention.</span> <span class="sd"> num_layers (`int`): Number of layers.</span> <span class="sd"> ffn_dims (`int`): Number of units in FFN.</span> <span class="sd"> ffn_dropout (`float`): Dropout probability for FFN.</span> <span class="sd"> num_heads (`int`): Number of attention heads.</span> <span class="sd"> key_dim (`int`): Size of each attention head for key.</span> <span class="sd"> attn_dropout (`float`): Dropout probability for attention module.</span> <span class="sd"> """</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">chunk_size</span><span class="p">,</span> <span class="n">r</span><span class="p">,</span> <span class="n">num_layers</span><span class="p">,</span> <span class="n">ffn_dims</span><span class="p">,</span> <span class="n">ffn_dropout</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">key_dim</span><span class="p">,</span> <span class="n">attn_dropout</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span> <span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="c1"># Save the arguments.</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span> <span class="o">=</span> <span class="n">chunk_size</span> <span class="bp">self</span><span class="o">.</span><span class="n">r</span> <span class="o">=</span> <span class="n">r</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_layers</span> <span class="o">=</span> <span class="n">num_layers</span> <span class="bp">self</span><span class="o">.</span><span class="n">ffn_dims</span> <span class="o">=</span> <span class="n">ffn_dims</span> <span class="bp">self</span><span class="o">.</span><span class="n">ffn_droput</span> <span class="o">=</span> <span class="n">ffn_dropout</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span> <span class="o">=</span> <span class="n">num_heads</span> <span class="bp">self</span><span class="o">.</span><span class="n">key_dim</span> <span class="o">=</span> <span class="n">key_dim</span> <span class="bp">self</span><span class="o">.</span><span class="n">attn_dropout</span> <span class="o">=</span> <span class="n">attn_dropout</span> <span class="c1"># Create state_size. This is important for</span> <span class="c1"># custom recurrence logic.</span> <span class="bp">self</span><span class="o">.</span><span class="n">state_size</span> <span class="o">=</span> <span class="n">chunk_size</span> <span class="o">*</span> <span class="n">ffn_dims</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_attention_scores</span> <span class="o">=</span> <span class="kc">False</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention_scores</span> <span class="o">=</span> <span class="p">[]</span> <span class="c1"># Perceptual Module</span> <span class="n">perceptual_module</span> <span class="o">=</span> <span class="nb">list</span><span class="p">()</span> <span class="k">for</span> <span class="n">layer_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_layers</span><span class="p">):</span> <span class="n">perceptual_module</span><span class="o">.</span><span class="n">append</span><span class="p">(</span> <span class="n">AttentionWithFFN</span><span class="p">(</span> <span class="n">ffn_dims</span><span class="o">=</span><span class="n">ffn_dims</span><span class="p">,</span> <span class="n">ffn_dropout</span><span class="o">=</span><span class="n">ffn_dropout</span><span class="p">,</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">key_dim</span><span class="o">=</span><span class="n">key_dim</span><span class="p">,</span> <span class="n">attn_dropout</span><span class="o">=</span><span class="n">attn_dropout</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">"pm_self_attn_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s2">"</span><span class="p">,</span> <span class="p">)</span> <span class="p">)</span> <span class="k">if</span> <span class="n">layer_idx</span> <span class="o">%</span> <span class="n">r</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> <span class="n">perceptual_module</span><span class="o">.</span><span class="n">append</span><span class="p">(</span> <span class="n">AttentionWithFFN</span><span class="p">(</span> <span class="n">ffn_dims</span><span class="o">=</span><span class="n">ffn_dims</span><span class="p">,</span> <span class="n">ffn_dropout</span><span class="o">=</span><span class="n">ffn_dropout</span><span class="p">,</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">key_dim</span><span class="o">=</span><span class="n">key_dim</span><span class="p">,</span> <span class="n">attn_dropout</span><span class="o">=</span><span class="n">attn_dropout</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">"pm_cross_attn_ffn_</span><span class="si">{</span><span class="n">layer_idx</span><span class="si">}</span><span class="s2">"</span><span class="p">,</span> <span class="p">)</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">perceptual_module</span> <span class="o">=</span> <span class="n">perceptual_module</span> <span class="c1"># Temporal Latent Bottleneck Module</span> <span class="bp">self</span><span class="o">.</span><span class="n">tlb_module</span> <span class="o">=</span> <span class="n">AttentionWithFFN</span><span class="p">(</span> <span class="n">ffn_dims</span><span class="o">=</span><span class="n">ffn_dims</span><span class="p">,</span> <span class="n">ffn_dropout</span><span class="o">=</span><span class="n">ffn_dropout</span><span class="p">,</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">key_dim</span><span class="o">=</span><span class="n">key_dim</span><span class="p">,</span> <span class="n">attn_dropout</span><span class="o">=</span><span class="n">attn_dropout</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">"tlb_cross_attn_ffn"</span><span class="p">,</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">built</span> <span class="o">=</span> <span class="kc">True</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">states</span><span class="p">):</span> <span class="c1"># inputs => (batch, chunk_size, dims)</span> <span class="c1"># states => [(batch, chunk_size, units)]</span> <span class="n">slow_stream</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">states</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">ffn_dims</span><span class="p">))</span> <span class="n">fast_stream</span> <span class="o">=</span> <span class="n">inputs</span> <span class="k">for</span> <span class="n">layer_idx</span><span class="p">,</span> <span class="n">layer</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">perceptual_module</span><span class="p">):</span> <span class="n">fast_stream</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span> <span class="n">query</span><span class="o">=</span><span class="n">fast_stream</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">fast_stream</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">fast_stream</span><span class="p">,</span> <span class="n">stream</span><span class="o">=</span><span class="s2">"fast"</span> <span class="p">)</span> <span class="k">if</span> <span class="n">layer_idx</span> <span class="o">%</span> <span class="bp">self</span><span class="o">.</span><span class="n">r</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> <span class="n">fast_stream</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span> <span class="n">query</span><span class="o">=</span><span class="n">fast_stream</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">slow_stream</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">slow_stream</span><span class="p">,</span> <span class="n">stream</span><span class="o">=</span><span class="s2">"slow"</span> <span class="p">)</span> <span class="n">slow_stream</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">tlb_module</span><span class="p">(</span> <span class="n">query</span><span class="o">=</span><span class="n">slow_stream</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">fast_stream</span><span class="p">,</span> <span class="n">value</span><span class="o">=</span><span class="n">fast_stream</span> <span class="p">)</span> <span class="c1"># Save the attention scores for later visualization.</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_attention_scores</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention_scores</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">tlb_module</span><span class="o">.</span><span class="n">attention_scores</span><span class="p">)</span> <span class="k">return</span> <span class="n">fast_stream</span><span class="p">,</span> <span class="p">[</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">slow_stream</span><span class="p">,</span> <span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">chunk_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">ffn_dims</span><span class="p">))</span> <span class="p">]</span> </code></pre></div> <h3 id="temporallatentbottleneckmodel-to-encapsulate-full-model"><code>TemporalLatentBottleneckModel</code> to encapsulate full model</h3> <p>Here, we just wrap the full model as to expose it for training.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">TemporalLatentBottleneckModel</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Model Trainer.</span> <span class="sd"> Args:</span> <span class="sd"> patch_layer ([`keras.layers.Layer`](/api/layers/base_layer#layer-class)): Patching layer.</span> <span class="sd"> custom_cell ([`keras.layers.Layer`](/api/layers/base_layer#layer-class)): Custom Recurrent Cell.</span> <span class="sd"> """</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">patch_layer</span><span class="p">,</span> <span class="n">custom_cell</span><span class="p">,</span> <span class="n">unroll_loops</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_layer</span> <span class="o">=</span> <span class="n">patch_layer</span> <span class="bp">self</span><span class="o">.</span><span class="n">rnn</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">RNN</span><span class="p">(</span><span class="n">custom_cell</span><span class="p">,</span> <span class="n">unroll</span><span class="o">=</span><span class="n">unroll_loops</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"rnn"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">gap</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">GlobalAveragePooling1D</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"gap"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">head</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"softmax"</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"head"</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_layer</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">rnn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">gap</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">head</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">outputs</span> </code></pre></div> <hr /> <h2 id="build-the-model">Build the model</h2> <p>To begin training, we now define the components individually and pass them as arguments to our wrapper class, which will prepare the final model for training. We define a <code>PatchEmbed</code> layer, and the <code>CustomCell</code>-based RNN.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Build the model.</span> <span class="n">patch_layer</span> <span class="o">=</span> <span class="n">PatchEmbedding</span><span class="p">(</span> <span class="n">image_size</span><span class="o">=</span><span class="p">(</span><span class="n">config</span><span class="p">[</span><span class="s2">"image_size"</span><span class="p">],</span> <span class="n">config</span><span class="p">[</span><span class="s2">"image_size"</span><span class="p">]),</span> <span class="n">patch_size</span><span class="o">=</span><span class="p">(</span><span class="n">config</span><span class="p">[</span><span class="s2">"patch_size"</span><span class="p">],</span> <span class="n">config</span><span class="p">[</span><span class="s2">"patch_size"</span><span class="p">]),</span> <span class="n">embed_dim</span><span class="o">=</span><span class="n">config</span><span class="p">[</span><span class="s2">"embed_dim"</span><span class="p">],</span> <span class="n">chunk_size</span><span class="o">=</span><span class="n">config</span><span class="p">[</span><span class="s2">"chunk_size"</span><span class="p">],</span> <span class="p">)</span> <span class="n">custom_rnn_cell</span> <span class="o">=</span> <span class="n">CustomRecurrentCell</span><span class="p">(</span> <span class="n">chunk_size</span><span class="o">=</span><span class="n">config</span><span class="p">[</span><span class="s2">"chunk_size"</span><span class="p">],</span> <span class="n">r</span><span class="o">=</span><span class="n">config</span><span class="p">[</span><span class="s2">"r"</span><span class="p">],</span> <span class="n">num_layers</span><span class="o">=</span><span class="n">config</span><span class="p">[</span><span class="s2">"num_layers"</span><span class="p">],</span> <span class="n">ffn_dims</span><span class="o">=</span><span class="n">config</span><span class="p">[</span><span class="s2">"embed_dim"</span><span class="p">],</span> <span class="n">ffn_dropout</span><span class="o">=</span><span class="n">config</span><span class="p">[</span><span class="s2">"ffn_drop"</span><span class="p">],</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">config</span><span class="p">[</span><span class="s2">"num_heads"</span><span class="p">],</span> <span class="n">key_dim</span><span class="o">=</span><span class="n">config</span><span class="p">[</span><span class="s2">"embed_dim"</span><span class="p">],</span> <span class="n">attn_dropout</span><span class="o">=</span><span class="n">config</span><span class="p">[</span><span class="s2">"attn_drop"</span><span class="p">],</span> <span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">TemporalLatentBottleneckModel</span><span class="p">(</span> <span class="n">patch_layer</span><span class="o">=</span><span class="n">patch_layer</span><span class="p">,</span> <span class="n">custom_cell</span><span class="o">=</span><span class="n">custom_rnn_cell</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="metrics-and-callbacks">Metrics and Callbacks</h2> <p>We use the <code>AdamW</code> optimizer since it has been shown to perform very well on several benchmark tasks from an optimization perspective. It is a version of the <a href="/api/optimizers/adam#adam-class"><code>keras.optimizers.Adam</code></a> optimizer, along with Weight Decay in place.</p> <p>For a loss function, we make use of the <a href="/api/losses/probabilistic_losses#sparsecategoricalcrossentropy-class"><code>keras.losses.SparseCategoricalCrossentropy</code></a> function that makes use of simple Cross-entropy between prediction and actual logits. We also calculate accuracy on our data as a sanity-check.</p> <div class="codehilite"><pre><span></span><code><span class="n">optimizer</span> <span class="o">=</span> <span class="n">AdamW</span><span class="p">(</span> <span class="n">learning_rate</span><span class="o">=</span><span class="n">config</span><span class="p">[</span><span class="s2">"learning_rate"</span><span class="p">],</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">config</span><span class="p">[</span><span class="s2">"weight_decay"</span><span class="p">]</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"sparse_categorical_crossentropy"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"accuracy"</span><span class="p">],</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="train-the-model-with-modelfit">Train the model with <code>model.fit()</code></h2> <p>We pass the training dataset and run training.</p> <div class="codehilite"><pre><span></span><code><span class="n">history</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">train_ds</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">config</span><span class="p">[</span><span class="s2">"epochs"</span><span class="p">],</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">val_ds</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1270s 62s/step - accuracy: 0.1166 - loss: 3.1132 - val_accuracy: 0.1486 - val_loss: 2.2887 Epoch 2/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1152s 60s/step - accuracy: 0.1798 - loss: 2.2290 - val_accuracy: 0.2249 - val_loss: 2.1083 Epoch 3/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1150s 60s/step - accuracy: 0.2371 - loss: 2.0661 - val_accuracy: 0.2610 - val_loss: 2.0294 Epoch 4/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1150s 60s/step - accuracy: 0.2631 - loss: 1.9997 - val_accuracy: 0.2765 - val_loss: 2.0008 Epoch 5/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1151s 60s/step - accuracy: 0.2869 - loss: 1.9634 - val_accuracy: 0.2985 - val_loss: 1.9578 Epoch 6/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1151s 60s/step - accuracy: 0.3048 - loss: 1.9314 - val_accuracy: 0.3055 - val_loss: 1.9324 Epoch 7/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1152s 60s/step - accuracy: 0.3136 - loss: 1.8977 - val_accuracy: 0.3209 - val_loss: 1.9050 Epoch 8/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1151s 60s/step - accuracy: 0.3238 - loss: 1.8717 - val_accuracy: 0.3231 - val_loss: 1.8874 Epoch 9/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1152s 60s/step - accuracy: 0.3414 - loss: 1.8453 - val_accuracy: 0.3445 - val_loss: 1.8334 Epoch 10/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1152s 60s/step - accuracy: 0.3469 - loss: 1.8119 - val_accuracy: 0.3591 - val_loss: 1.8019 Epoch 11/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1151s 60s/step - accuracy: 0.3648 - loss: 1.7712 - val_accuracy: 0.3793 - val_loss: 1.7513 Epoch 12/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.3730 - loss: 1.7332 - val_accuracy: 0.3667 - val_loss: 1.7464 Epoch 13/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1148s 60s/step - accuracy: 0.3918 - loss: 1.6986 - val_accuracy: 0.3995 - val_loss: 1.6843 Epoch 14/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1147s 60s/step - accuracy: 0.3975 - loss: 1.6679 - val_accuracy: 0.4026 - val_loss: 1.6602 Epoch 15/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4078 - loss: 1.6400 - val_accuracy: 0.3990 - val_loss: 1.6536 Epoch 16/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4135 - loss: 1.6224 - val_accuracy: 0.4216 - val_loss: 1.6144 Epoch 17/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1147s 60s/step - accuracy: 0.4254 - loss: 1.5884 - val_accuracy: 0.4281 - val_loss: 1.5788 Epoch 18/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4383 - loss: 1.5614 - val_accuracy: 0.4294 - val_loss: 1.5731 Epoch 19/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4419 - loss: 1.5440 - val_accuracy: 0.4338 - val_loss: 1.5633 Epoch 20/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4439 - loss: 1.5268 - val_accuracy: 0.4430 - val_loss: 1.5211 Epoch 21/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1147s 60s/step - accuracy: 0.4509 - loss: 1.5108 - val_accuracy: 0.4504 - val_loss: 1.5054 Epoch 22/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4629 - loss: 1.4828 - val_accuracy: 0.4563 - val_loss: 1.4974 Epoch 23/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1145s 60s/step - accuracy: 0.4660 - loss: 1.4682 - val_accuracy: 0.4647 - val_loss: 1.4794 Epoch 24/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4680 - loss: 1.4524 - val_accuracy: 0.4640 - val_loss: 1.4681 Epoch 25/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1145s 60s/step - accuracy: 0.4786 - loss: 1.4297 - val_accuracy: 0.4663 - val_loss: 1.4496 Epoch 26/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4889 - loss: 1.4149 - val_accuracy: 0.4769 - val_loss: 1.4350 Epoch 27/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.4925 - loss: 1.4009 - val_accuracy: 0.4808 - val_loss: 1.4317 Epoch 28/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1145s 60s/step - accuracy: 0.4907 - loss: 1.3994 - val_accuracy: 0.4810 - val_loss: 1.4307 Epoch 29/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.5000 - loss: 1.3832 - val_accuracy: 0.4844 - val_loss: 1.3996 Epoch 30/30 19/19 ━━━━━━━━━━━━━━━━━━━━ 1146s 60s/step - accuracy: 0.5076 - loss: 1.3592 - val_accuracy: 0.4890 - val_loss: 1.3961 --- ## Visualize training metrics The `model.fit()` will return a `history` object, which stores the values of the metrics generated during the training run (but it is ephemeral and needs to be saved manually). We now display the Loss and Accuracy curves for the training and validation sets. ```python plt.plot(history.history["loss"], label="loss") plt.plot(history.history["val_loss"], label="val_loss") plt.legend() plt.show() plt.plot(history.history["accuracy"], label="accuracy") plt.plot(history.history["val_accuracy"], label="val_accuracy") plt.legend() plt.show() </code></pre></div> ![png](/img/examples/vision/temporal_latent_bottleneck/temporal_latent_bottleneck_32_0.png) ![png](/img/examples/vision/temporal_latent_bottleneck/temporal_latent_bottleneck_32_1.png) --- ## Visualize attention maps from the Temporal Latent Bottleneck Now that we have trained our model, it is time for some visualizations. The Fast Stream (Transformers) processes a chunk of tokens. The Slow Stream processes each chunk and attends to tokens that are useful for the task. In this section we visualize the attention map of the Slow Stream. This is done by extracting the attention scores from the TLB layer at each chunk's intersection and storing it within the RNN's state. This is followed by 'ballooning' it up and returning these values. <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">score_to_viz</span><span class="p">(</span><span class="n">chunk_score</span><span class="p">):</span> <span class="c1"># get the most attended token</span> <span class="n">chunk_viz</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">chunk_score</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">2</span><span class="p">)</span> <span class="c1"># get the mean across heads</span> <span class="n">chunk_viz</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">chunk_viz</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="k">return</span> <span class="n">chunk_viz</span> <span class="c1"># Get a batch of images and labels from the testing dataset</span> <span class="n">images</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">test_ds</span><span class="p">))</span> <span class="c1"># Create a new model instance that is executed eagerly to allow saving</span> <span class="c1"># attention scores. This also requires unrolling loops</span> <span class="n">eager_model</span> <span class="o">=</span> <span class="n">TemporalLatentBottleneckModel</span><span class="p">(</span> <span class="n">patch_layer</span><span class="o">=</span><span class="n">patch_layer</span><span class="p">,</span> <span class="n">custom_cell</span><span class="o">=</span><span class="n">custom_rnn_cell</span><span class="p">,</span> <span class="n">unroll_loops</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="n">eager_model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">run_eagerly</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">jit_compile</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s2">"weights.keras"</span><span class="p">)</span> <span class="n">eager_model</span><span class="o">.</span><span class="n">load_weights</span><span class="p">(</span><span class="s2">"weights.keras"</span><span class="p">)</span> <span class="c1"># Set the get_attn_scores flag to True</span> <span class="n">eager_model</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">cell</span><span class="o">.</span><span class="n">get_attention_scores</span> <span class="o">=</span> <span class="kc">True</span> <span class="c1"># Run the model with the testing images and grab the</span> <span class="c1"># attention scores.</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">eager_model</span><span class="p">(</span><span class="n">images</span><span class="p">)</span> <span class="n">list_chunk_scores</span> <span class="o">=</span> <span class="n">eager_model</span><span class="o">.</span><span class="n">rnn</span><span class="o">.</span><span class="n">cell</span><span class="o">.</span><span class="n">attention_scores</span> <span class="c1"># Process the attention scores in order to visualize them</span> <span class="n">num_chunks</span> <span class="o">=</span> <span class="p">(</span><span class="n">config</span><span class="p">[</span><span class="s2">"image_size"</span><span class="p">]</span> <span class="o">//</span> <span class="n">config</span><span class="p">[</span><span class="s2">"patch_size"</span><span class="p">])</span> <span class="o">**</span> <span class="mi">2</span> <span class="o">//</span> <span class="n">config</span><span class="p">[</span><span class="s2">"chunk_size"</span><span class="p">]</span> <span class="n">list_chunk_viz</span> <span class="o">=</span> <span class="p">[</span><span class="n">score_to_viz</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">list_chunk_scores</span><span class="p">[</span><span class="o">-</span><span class="n">num_chunks</span><span class="p">:]]</span> <span class="n">chunk_viz</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span><span class="n">list_chunk_viz</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="n">chunk_viz</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span> <span class="n">chunk_viz</span><span class="p">,</span> <span class="p">(</span> <span class="n">config</span><span class="p">[</span><span class="s2">"batch_size"</span><span class="p">],</span> <span class="n">config</span><span class="p">[</span><span class="s2">"image_size"</span><span class="p">]</span> <span class="o">//</span> <span class="n">config</span><span class="p">[</span><span class="s2">"patch_size"</span><span class="p">],</span> <span class="n">config</span><span class="p">[</span><span class="s2">"image_size"</span><span class="p">]</span> <span class="o">//</span> <span class="n">config</span><span class="p">[</span><span class="s2">"patch_size"</span><span class="p">],</span> <span class="mi">1</span><span class="p">,</span> <span class="p">),</span> <span class="p">)</span> <span class="n">upsampled_heat_map</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">UpSampling2D</span><span class="p">(</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">),</span> <span class="n">interpolation</span><span class="o">=</span><span class="s2">"bilinear"</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span> <span class="p">)(</span><span class="n">chunk_viz</span><span class="p">)</span> </code></pre></div> Run the following code snippet to get different images and their attention maps. <div class="codehilite"><pre><span></span><code><span class="c1"># Sample a random image</span> <span class="n">index</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">config</span><span class="p">[</span><span class="s2">"batch_size"</span><span class="p">])</span> <span class="n">orig_image</span> <span class="o">=</span> <span class="n">images</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="n">overlay_image</span> <span class="o">=</span> <span class="n">upsampled_heat_map</span><span class="p">[</span><span class="n">index</span><span class="p">,</span> <span class="o">...</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span class="k">if</span> <span class="n">keras</span><span class="o">.</span><span class="n">backend</span><span class="o">.</span><span class="n">backend</span><span class="p">()</span> <span class="o">==</span> <span class="s2">"torch"</span><span class="p">:</span> <span class="c1"># when using the torch backend, we are required to ensure that the</span> <span class="c1"># image is copied from the GPU</span> <span class="n">orig_image</span> <span class="o">=</span> <span class="n">orig_image</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="n">overlay_image</span> <span class="o">=</span> <span class="n">overlay_image</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="c1"># Plot the visualization</span> <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">nrows</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">ncols</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span> <span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">orig_image</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">"Original:"</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="n">image</span> <span class="o">=</span> <span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">orig_image</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span> <span class="n">overlay_image</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s2">"inferno"</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.6</span><span class="p">,</span> <span class="n">extent</span><span class="o">=</span><span class="n">image</span><span class="o">.</span><span class="n">get_extent</span><span class="p">(),</span> <span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">"TLB Attention:"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> </code></pre></div> ![png](/img/examples/vision/temporal_latent_bottleneck/temporal_latent_bottleneck_36_0.png) --- ## Conclusion This example has demonstrated an implementation of the Temporal Latent Bottleneck mechanism. The example highlights the use of compression and storage of historical states in the form of a Temporal Latent Bottleneck with regular updates from a Perceptual Module as an effective method to do so. In the original paper, the authors have conducted highly extensive tests around different modalities ranging from Supervised Image Classification to applications in Reinforcement Learning. While we have only displayed a method to apply this mechanism to Image Classification, it can be extended to other modalities too with minimal changes. *Note*: While building this example we did not have the official code to refer to. This means that our implementation is inspired by the paper with no claims of being a complete reproduction. For more details on the training process one can head over to [our GitHub repository](https://github.com/suvadityamuk/Temporal-Latent-Bottleneck-TF). --- ## Acknowledgement Thanks to [Aniket Didolkar](https://www.aniketdidolkar.in/) (the first author) and [Anirudh Goyal](https://anirudh9119.github.io/) (the third author) for revieweing our work. We would like to thank [PyImageSearch](https://pyimagesearch.com/) for a Colab Pro account and [JarvisLabs.ai](https://cloud.jarvislabs.ai/) for the GPU credits. </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#when-recurrence-meets-transformers'>When Recurrence meets Transformers</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup-imports'>Setup imports</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setting-required-configuration'>Setting required configuration</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#loading-the-cifar10-dataset'>Loading the CIFAR-10 dataset</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#define-data-augmentation-for-the-training-and-validationtest-pipelines'>Define data augmentation for the training and validation/test pipelines</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#load-dataset-into-pydataset-object'>Load dataset into <code>PyDataset</code> object</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#temporal-latent-bottleneck'>Temporal Latent Bottleneck</a> </div> <div class='k-outline-depth-3'> <a href='#patchembedding-layer'><code>PatchEmbedding</code> layer</a> </div> <div class='k-outline-depth-3'> <a href='#feedforwardnetwork-layer'><code>FeedForwardNetwork</code> Layer</a> </div> <div class='k-outline-depth-3'> <a href='#baseattention-layer'><code>BaseAttention</code> layer</a> </div> <div class='k-outline-depth-3'> <a href='#attention-with-feedforwardnetwork-layer'><code>Attention</code> with <code>FeedForwardNetwork</code> layer</a> </div> <div class='k-outline-depth-3'> <a href='#custom-rnn-cell-for-temporal-latent-bottleneck-and-perceptual-module'>Custom RNN Cell for **Temporal Latent Bottleneck** and **Perceptual Module**</a> </div> <div class='k-outline-depth-3'> <a href='#temporallatentbottleneckmodel-to-encapsulate-full-model'><code>TemporalLatentBottleneckModel</code> to encapsulate full model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#build-the-model'>Build the model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#metrics-and-callbacks'>Metrics and Callbacks</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-the-model-with-modelfit'>Train the model with <code>model.fit()</code></a> </div> <div class='k-outline-depth-1'> <a href='#get-a-batch-of-images-and-labels-from-the-testing-dataset'>Get a batch of images and labels from the testing dataset</a> </div> <div class='k-outline-depth-1'> <a href='#create-a-new-model-instance-that-is-executed-eagerly-to-allow-saving'>Create a new model instance that is executed eagerly to allow saving</a> </div> <div class='k-outline-depth-1'> <a href='#attention-scores-this-also-requires-unrolling-loops'>attention scores. This also requires unrolling loops</a> </div> <div class='k-outline-depth-1'> <a href='#set-the-getattnscores-flag-to-true'>Set the get_attn_scores flag to True</a> </div> <div class='k-outline-depth-1'> <a href='#run-the-model-with-the-testing-images-and-grab-the'>Run the model with the testing images and grab the</a> </div> <div class='k-outline-depth-1'> <a href='#attention-scores'>attention scores.</a> </div> <div class='k-outline-depth-1'> <a href='#process-the-attention-scores-in-order-to-visualize-them'>Process the attention scores in order to visualize them</a> </div> <div class='k-outline-depth-1'> <a href='#sample-a-random-image'>Sample a random image</a> </div> <div class='k-outline-depth-1'> <a href='#plot-the-visualization'>Plot the visualization</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>