<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/video_transformers/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Video Classification with Transformers"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Video Classification with Transformers"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Video Classification with Transformers</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 
'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a 
class="nav-sublink2" href="/examples/vision/mobilevit/">A mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using 
Composable Fully-Convolutional Networks</a> <a class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved 
Robustness</a> <a class="nav-sublink2" href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2 active" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a 
class="nav-sublink2" href="/examples/vision/bit/">Image Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a 
class="nav-sublink2" href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = 
document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div 
class='k-main-inner' id='k-main-id'> <div class='k-content'> <h1 id="video-classification-with-transformers">Video Classification with Transformers</h1> <p><strong>Author:</strong> <a href="https://twitter.com/RisingSayak">Sayak Paul</a><br> <strong>Date created:</strong> 2021/06/08<br> <strong>Last modified:</strong> 2023/07/22<br> <strong>Description:</strong> Training a video classifier with hybrid transformers.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/video_transformers.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/video_transformers.py"><strong>GitHub source</strong></a></p> <p>This example is a follow-up to the <a href="https://keras.io/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> example. This time, we will use a Transformer-based model (<a href="https://arxiv.org/abs/1706.03762">Vaswani et al.</a>) to classify videos. If you need an introduction to Transformers (with code), see <a href="https://livebook.manning.com/book/deep-learning-with-python-second-edition/chapter-11">this book chapter</a>.
After reading this example, you will know how to develop hybrid Transformer-based models for video classification that operate on CNN feature maps.</p> <div class="codehilite"><pre><span></span><code><span class="err">!</span><span class="n">pip</span> <span class="n">install</span> <span class="o">-</span><span class="n">q</span> <span class="n">git</span><span class="o">+</span><span class="n">https</span><span class="p">:</span><span class="o">//</span><span class="n">github</span><span class="o">.</span><span class="n">com</span><span class="o">/</span><span class="n">tensorflow</span><span class="o">/</span><span class="n">docs</span> </code></pre></div> <hr /> <h2 id="data-collection">Data collection</h2> <p>As in the <a href="https://keras.io/examples/vision/video_classification/">predecessor</a> to this example, we will use a subsampled version of the <a href="https://www.crcv.ucf.edu/data/UCF101.php">UCF101 dataset</a>, a well-known action recognition benchmark. To work with a larger subsample or the entire dataset, refer to <a href="https://colab.research.google.com/github/sayakpaul/Action-Recognition-in-TensorFlow/blob/main/Data_Preparation_UCF101.ipynb">this notebook</a>.</p> <div class="codehilite"><pre><span></span><code><span class="err">!</span><span class="n">wget</span> <span class="o">-</span><span class="n">q</span> <span class="n">https</span><span class="p">:</span><span class="o">//</span><span class="n">github</span><span class="o">.</span><span class="n">com</span><span class="o">/</span><span class="n">sayakpaul</span><span class="o">/</span><span class="n">Action</span><span class="o">-</span><span class="n">Recognition</span><span class="o">-</span><span class="ow">in</span><span class="o">-</span><span class="n">TensorFlow</span><span class="o">/</span><span class="n">releases</span><span class="o">/</span><span class="n">download</span><span class="o">/</span><span class="n">v1</span><span
class="mf">.0.0</span><span class="o">/</span><span class="n">ucf101_top5</span><span class="o">.</span><span class="n">tar</span><span class="o">.</span><span class="n">gz</span> <span class="err">!</span><span class="n">tar</span> <span class="o">-</span><span class="n">xf</span> <span class="n">ucf101_top5</span><span class="o">.</span><span class="n">tar</span><span class="o">.</span><span class="n">gz</span> </code></pre></div> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> <span class="kn">from</span> <span class="nn">keras.applications.densenet</span> <span class="kn">import</span> <span class="n">DenseNet121</span> <span class="kn">from</span> <span class="nn">tensorflow_docs.vis</span> <span class="kn">import</span> <span class="n">embed</span> <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span> <span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">imageio</span> <span class="kn">import</span> <span class="nn">cv2</span> </code></pre></div> <hr /> <h2 id="define-hyperparameters">Define hyperparameters</h2> <div class="codehilite"><pre><span></span><code><span class="n">MAX_SEQ_LENGTH</span> <span class="o">=</span> <span class="mi">20</span> <span class="n">NUM_FEATURES</span> <span class="o">=</span> <span class="mi">1024</span> <span class="n">IMG_SIZE</span> <span class="o">=</span> <span class="mi">128</span> <span class="n">EPOCHS</span> <span class="o">=</span> <span class="mi">5</span> 
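# Rough size check: a sketch assuming float32 features, using the
# hyperparameters defined just above. DenseNet121 with global average pooling
# emits NUM_FEATURES (1024) values per frame, so the first MAX_SEQ_LENGTH
# frames of each video shrink from raw pixels to a much smaller feature
# matrix before they reach the Transformer. The variable names below are
# introduced only for this illustration.
raw_values_per_video = MAX_SEQ_LENGTH * IMG_SIZE * IMG_SIZE * 3  # 983,040
feature_values_per_video = MAX_SEQ_LENGTH * NUM_FEATURES  # 20,480
reduction_factor = raw_values_per_video // feature_values_per_video  # 48
feature_bytes_per_video = feature_values_per_video * 4  # 81,920 bytes (float32)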
</code></pre></div> <hr /> <h2 id="data-preparation">Data preparation</h2> <p>We mostly follow the same data preparation steps as in the predecessor example, with the following changes:</p> <ul> <li>We reduce the image size to 128x128 instead of 224x224 to speed up computation.</li> <li>Instead of a pre-trained <a href="https://arxiv.org/abs/1512.00567">InceptionV3</a> network, we use a pre-trained <a href="http://openaccess.thecvf.com/content_cvpr_2017/papers/Huang_Densely_Connected_Convolutional_CVPR_2017_paper.pdf">DenseNet121</a> for feature extraction.</li> <li>We directly pad shorter videos with all-zero frames up to a length of <code>MAX_SEQ_LENGTH</code> frames.</li> </ul> <p>First, let's load the <a href="https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html">DataFrames</a>.</p> <div class="codehilite"><pre><span></span><code><span class="n">train_df</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_csv</span><span class="p">(</span><span class="s2">"train.csv"</span><span class="p">)</span> <span class="n">test_df</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_csv</span><span class="p">(</span><span class="s2">"test.csv"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Total videos for training: </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">train_df</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Total videos for testing: </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">test_df</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span
class="n">center_crop_layer</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">CenterCrop</span><span class="p">(</span><span class="n">IMG_SIZE</span><span class="p">,</span> <span class="n">IMG_SIZE</span><span class="p">)</span> <span class="k">def</span> <span class="nf">crop_center</span><span class="p">(</span><span class="n">frame</span><span class="p">):</span> <span class="n">cropped</span> <span class="o">=</span> <span class="n">center_crop_layer</span><span class="p">(</span><span class="n">frame</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="o">...</span><span class="p">])</span> <span class="n">cropped</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">convert_to_numpy</span><span class="p">(</span><span class="n">cropped</span><span class="p">)</span> <span class="n">cropped</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">cropped</span><span class="p">)</span> <span class="k">return</span> <span class="n">cropped</span> <span class="c1"># Following method is modified from this tutorial:</span> <span class="c1"># https://www.tensorflow.org/hub/tutorials/action_recognition_with_tf_hub</span> <span class="k">def</span> <span class="nf">load_video</span><span class="p">(</span><span class="n">path</span><span class="p">,</span> <span class="n">max_frames</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">offload_to_cpu</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span> <span class="n">cap</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">VideoCapture</span><span 
class="p">(</span><span class="n">path</span><span class="p">)</span> <span class="n">frames</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">try</span><span class="p">:</span> <span class="k">while</span> <span class="kc">True</span><span class="p">:</span> <span class="n">ret</span><span class="p">,</span> <span class="n">frame</span> <span class="o">=</span> <span class="n">cap</span><span class="o">.</span><span class="n">read</span><span class="p">()</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">ret</span><span class="p">:</span> <span class="k">break</span> <span class="c1"># OpenCV decodes frames in BGR order; reorder the channels to RGB.</span> <span class="n">frame</span> <span class="o">=</span> <span class="n">frame</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">[</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">]]</span> <span class="n">frame</span> <span class="o">=</span> <span class="n">crop_center</span><span class="p">(</span><span class="n">frame</span><span class="p">)</span> <span class="k">if</span> <span class="n">offload_to_cpu</span> <span class="ow">and</span> <span class="n">keras</span><span class="o">.</span><span class="n">backend</span><span class="o">.</span><span class="n">backend</span><span class="p">()</span> <span class="o">==</span> <span class="s2">"torch"</span><span class="p">:</span> <span class="n">frame</span> <span class="o">=</span> <span class="n">frame</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">"cpu"</span><span class="p">)</span> <span class="n">frames</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">frame</span><span class="p">)</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">frames</span><span class="p">)</span> <span class="o">==</span> <span
class="n">max_frames</span><span class="p">:</span> <span class="k">break</span> <span class="k">finally</span><span class="p">:</span> <span class="n">cap</span><span class="o">.</span><span class="n">release</span><span class="p">()</span> <span class="k">if</span> <span class="n">offload_to_cpu</span> <span class="ow">and</span> <span class="n">keras</span><span class="o">.</span><span class="n">backend</span><span class="o">.</span><span class="n">backend</span><span class="p">()</span> <span class="o">==</span> <span class="s2">"torch"</span><span class="p">:</span> <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">frame</span><span class="o">.</span><span class="n">to</span><span class="p">(</span><span class="s2">"cpu"</span><span class="p">)</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="k">for</span> <span class="n">frame</span> <span class="ow">in</span> <span class="n">frames</span><span class="p">])</span> <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">frames</span><span class="p">)</span> <span class="k">def</span> <span class="nf">build_feature_extractor</span><span class="p">():</span> <span class="n">feature_extractor</span> <span class="o">=</span> <span class="n">DenseNet121</span><span class="p">(</span> <span class="n">weights</span><span class="o">=</span><span class="s2">"imagenet"</span><span class="p">,</span> <span class="n">include_top</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">pooling</span><span class="o">=</span><span class="s2">"avg"</span><span class="p">,</span> <span class="n">input_shape</span><span class="o">=</span><span class="p">(</span><span class="n">IMG_SIZE</span><span class="p">,</span> <span class="n">IMG_SIZE</span><span 
class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="p">)</span> <span class="n">preprocess_input</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">applications</span><span class="o">.</span><span class="n">densenet</span><span class="o">.</span><span class="n">preprocess_input</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">((</span><span class="n">IMG_SIZE</span><span class="p">,</span> <span class="n">IMG_SIZE</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="n">preprocessed</span> <span class="o">=</span> <span class="n">preprocess_input</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">feature_extractor</span><span class="p">(</span><span class="n">preprocessed</span><span class="p">)</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"feature_extractor"</span><span class="p">)</span> <span class="n">feature_extractor</span> <span class="o">=</span> <span class="n">build_feature_extractor</span><span class="p">()</span> <span class="c1"># Label preprocessing with StringLookup.</span> <span class="n">label_processor</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">StringLookup</span><span class="p">(</span> <span class="n">num_oov_indices</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">vocabulary</span><span 
class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">unique</span><span class="p">(</span><span class="n">train_df</span><span class="p">[</span><span class="s2">"tag"</span><span class="p">]),</span> <span class="n">mask_token</span><span class="o">=</span><span class="kc">None</span> <span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">label_processor</span><span class="o">.</span><span class="n">get_vocabulary</span><span class="p">())</span> <span class="k">def</span> <span class="nf">prepare_all_videos</span><span class="p">(</span><span class="n">df</span><span class="p">,</span> <span class="n">root_dir</span><span class="p">):</span> <span class="n">num_samples</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">df</span><span class="p">)</span> <span class="n">video_paths</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="s2">"video_name"</span><span class="p">]</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="s2">"tag"</span><span class="p">]</span><span class="o">.</span><span class="n">values</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">label_processor</span><span class="p">(</span><span class="n">labels</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">])</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="c1"># `frame_features` are what we will feed to our sequence model.</span> <span class="n">frame_features</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span> 
<span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">num_samples</span><span class="p">,</span> <span class="n">MAX_SEQ_LENGTH</span><span class="p">,</span> <span class="n">NUM_FEATURES</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span> <span class="p">)</span> <span class="c1"># For each video.</span> <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">path</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">video_paths</span><span class="p">):</span> <span class="c1"># Gather all its frames and add a batch dimension.</span> <span class="n">frames</span> <span class="o">=</span> <span class="n">load_video</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">root_dir</span><span class="p">,</span> <span class="n">path</span><span class="p">))</span> <span class="c1"># Pad shorter videos.</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">frames</span><span class="p">)</span> <span class="o"><</span> <span class="n">MAX_SEQ_LENGTH</span><span class="p">:</span> <span class="n">diff</span> <span class="o">=</span> <span class="n">MAX_SEQ_LENGTH</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">frames</span><span class="p">)</span> <span class="n">padding</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">diff</span><span class="p">,</span> <span class="n">IMG_SIZE</span><span class="p">,</span> <span class="n">IMG_SIZE</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="n">frames</span> 
<span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">frames</span><span class="p">,</span> <span class="n">padding</span><span class="p">])</span> <span class="n">frames</span> <span class="o">=</span> <span class="n">frames</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="c1"># Initialize placeholder to store the features of the current video.</span> <span class="n">temp_frame_features</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">MAX_SEQ_LENGTH</span><span class="p">,</span> <span class="n">NUM_FEATURES</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span> <span class="p">)</span> <span class="c1"># Extract features from the frames of the current video.</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">batch</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">frames</span><span class="p">):</span> <span class="n">video_length</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="n">length</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">MAX_SEQ_LENGTH</span><span class="p">,</span> <span class="n">video_length</span><span class="p">)</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span 
class="n">length</span><span class="p">):</span> <span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">batch</span><span class="p">[</span><span class="n">j</span><span class="p">,</span> <span class="p">:])</span> <span class="o">></span> <span class="mf">0.0</span><span class="p">:</span> <span class="n">temp_frame_features</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">feature_extractor</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span> <span class="n">batch</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="p">:]</span> <span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">temp_frame_features</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="mf">0.0</span> <span class="n">frame_features</span><span class="p">[</span><span class="n">idx</span><span class="p">,]</span> <span class="o">=</span> <span class="n">temp_frame_features</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span> <span class="k">return</span> <span class="n">frame_features</span><span class="p">,</span> <span class="n">labels</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Total videos for training: 594 Total videos for testing: 224 ['CricketShot', 'PlayingCello', 'Punch', 'ShavingBeard', 'TennisSwing'] </code></pre></div> </div> <p>Calling <code>prepare_all_videos()</code> on <code>train_df</code> and <code>test_df</code> takes ~20 minutes to complete. 
To save time, we instead download NumPy arrays that were preprocessed in advance:</p> <div class="codehilite"><pre><span></span><code><span class="err">!!</span><span class="n">wget</span> <span class="o">-</span><span class="n">q</span> <span class="n">https</span><span class="p">:</span><span class="o">//</span><span class="n">git</span><span class="o">.</span><span class="n">io</span><span class="o">/</span><span class="n">JZmf4</span> <span class="o">-</span><span class="n">O</span> <span class="n">top5_data_prepared</span><span class="o">.</span><span class="n">tar</span><span class="o">.</span><span class="n">gz</span> <span class="err">!!</span><span class="n">tar</span> <span class="o">-</span><span class="n">xf</span> <span class="n">top5_data_prepared</span><span class="o">.</span><span class="n">tar</span><span class="o">.</span><span class="n">gz</span> </code></pre></div> <div class="codehilite"><pre><span></span><code><span class="n">train_data</span><span class="p">,</span> <span class="n">train_labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"train_data.npy"</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"train_labels.npy"</span><span class="p">)</span> <span class="n">test_data</span><span class="p">,</span> <span class="n">test_labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"test_data.npy"</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"test_labels.npy"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Frame features in train set: </span><span 
class="si">{</span><span class="n">train_data</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Frame features in train set: (594, 20, 1024) </code></pre></div> </div> <hr /> <h2 id="building-the-transformerbased-model">Building the Transformer-based model</h2> <p>We build on top of the code shared in <a href="https://livebook.manning.com/book/deep-learning-with-python-second-edition/chapter-11">this book chapter</a> of <a href="https://www.manning.com/books/deep-learning-with-python">Deep Learning with Python (Second ed.)</a> by François Chollet.</p> <p>First, the self-attention layers that form the basic blocks of a Transformer are order-agnostic: shuffling the input frames only shuffles the outputs in the same way. Since videos are ordered sequences of frames, the model needs explicit order information, which we provide via <strong>positional encoding</strong>. We embed the position of each frame within its video (0 to <code>MAX_SEQ_LENGTH - 1</code>) with an <a href="https://keras.io/api/layers/core_layers/embedding"><code>Embedding</code> layer</a>. 
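</p> <p>As a rough illustration of the lookup described here (and the addition described next), the NumPy sketch below mimics what the <code>Embedding</code> layer does inside <code>PositionalEmbedding</code>. It is a simplified stand-in, not the layer itself: the random <code>embedding_table</code> plays the role of the layer's trainable weights, and the shapes (20 frames, 1024 features per frame) are assumed to match the arrays loaded above.</p> <div class="codehilite"><pre><span></span><code>import numpy as np

# Simplified stand-in for PositionalEmbedding: a random table replaces the
# trainable Embedding weights. Shapes assume 20 frames x 1024 features.
rng = np.random.default_rng(0)
sequence_length, num_features, batch_size = 20, 1024, 2

# One row per frame position (random here, learned in the real layer).
embedding_table = rng.normal(size=(sequence_length, num_features)).astype("float32")

# Precomputed CNN features: (batch_size, frames, num_features).
frame_features = rng.normal(size=(batch_size, sequence_length, num_features)).astype("float32")
# Make frames 0 and 1 of the first video identical on purpose.
frame_features[0, 1] = frame_features[0, 0]

# Look up one embedding per position 0..19, as Embedding(positions) would.
positions = np.arange(sequence_length)
embedded_positions = embedding_table[positions]  # shape (20, 1024)

# Broadcasting adds the same positional vectors to every video in the batch.
encoded = frame_features + embedded_positions  # shape (2, 20, 1024)

print(encoded.shape)  # (2, 20, 1024)
# The two identical frames are now distinguishable by their position.
print(np.allclose(encoded[0, 0], encoded[0, 1]))  # False
</code></pre></div> <p>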
We then add these positional embeddings to the precomputed per-frame CNN features.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">PositionalEmbedding</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">sequence_length</span><span class="p">,</span> <span class="n">output_dim</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">position_embeddings</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span> <span class="n">input_dim</span><span class="o">=</span><span class="n">sequence_length</span><span class="p">,</span> <span class="n">output_dim</span><span class="o">=</span><span class="n">output_dim</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length</span> <span class="o">=</span> <span class="n">sequence_length</span> <span class="bp">self</span><span class="o">.</span><span class="n">output_dim</span> <span class="o">=</span> <span class="n">output_dim</span> <span class="k">def</span> <span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">position_embeddings</span><span class="o">.</span><span class="n">build</span><span 
class="p">(</span><span class="n">input_shape</span><span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="c1"># The inputs are of shape: `(batch_size, frames, num_features)`</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">compute_dtype</span><span class="p">)</span> <span class="n">length</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span> <span class="n">positions</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">start</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">stop</span><span class="o">=</span><span class="n">length</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="n">embedded_positions</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">position_embeddings</span><span class="p">(</span><span class="n">positions</span><span class="p">)</span> <span class="k">return</span> <span class="n">inputs</span> <span class="o">+</span> <span class="n">embedded_positions</span> </code></pre></div> <p>Now, we can create a subclassed layer for the 
Transformer.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">TransformerEncoder</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">embed_dim</span><span class="p">,</span> <span class="n">dense_dim</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">embed_dim</span> <span class="o">=</span> <span class="n">embed_dim</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_dim</span> <span class="o">=</span> <span class="n">dense_dim</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_heads</span> <span class="o">=</span> <span class="n">num_heads</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MultiHeadAttention</span><span class="p">(</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">key_dim</span><span class="o">=</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.3</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_proj</span> <span class="o">=</span> <span 
class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">dense_dim</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">activations</span><span class="o">.</span><span class="n">gelu</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">),</span> <span class="p">]</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">layernorm_1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">layernorm_2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">()</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">mask</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">attention_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">attention</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">attention_mask</span><span class="o">=</span><span class="n">mask</span><span class="p">)</span> <span class="n">proj_input</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span 
class="n">layernorm_1</span><span class="p">(</span><span class="n">inputs</span> <span class="o">+</span> <span class="n">attention_output</span><span class="p">)</span> <span class="n">proj_output</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense_proj</span><span class="p">(</span><span class="n">proj_input</span><span class="p">)</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">layernorm_2</span><span class="p">(</span><span class="n">proj_input</span> <span class="o">+</span> <span class="n">proj_output</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="utility-functions-for-training">Utility functions for training</h2> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">get_compiled_model</span><span class="p">(</span><span class="n">shape</span><span class="p">):</span> <span class="n">sequence_length</span> <span class="o">=</span> <span class="n">MAX_SEQ_LENGTH</span> <span class="n">embed_dim</span> <span class="o">=</span> <span class="n">NUM_FEATURES</span> <span class="n">dense_dim</span> <span class="o">=</span> <span class="mi">4</span> <span class="n">num_heads</span> <span class="o">=</span> <span class="mi">1</span> <span class="n">classes</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">label_processor</span><span class="o">.</span><span class="n">get_vocabulary</span><span class="p">())</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">shape</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">PositionalEmbedding</span><span class="p">(</span> <span class="n">sequence_length</span><span class="p">,</span> 
<span class="n">embed_dim</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"frame_position_embedding"</span> <span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">TransformerEncoder</span><span class="p">(</span><span class="n">embed_dim</span><span class="p">,</span> <span class="n">dense_dim</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"transformer_layer"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">GlobalMaxPooling1D</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">classes</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"softmax"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span 
class="s2">"adam"</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"sparse_categorical_crossentropy"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"accuracy"</span><span class="p">],</span> <span class="p">)</span> <span class="k">return</span> <span class="n">model</span> <span class="k">def</span> <span class="nf">run_experiment</span><span class="p">():</span> <span class="n">filepath</span> <span class="o">=</span> <span class="s2">"/tmp/video_classifier.weights.h5"</span> <span class="n">checkpoint</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">ModelCheckpoint</span><span class="p">(</span> <span class="n">filepath</span><span class="p">,</span> <span class="n">save_weights_only</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">save_best_only</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">1</span> <span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">get_compiled_model</span><span class="p">(</span><span class="n">train_data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">:])</span> <span class="n">history</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">train_data</span><span class="p">,</span> <span class="n">train_labels</span><span class="p">,</span> <span class="n">validation_split</span><span class="o">=</span><span class="mf">0.15</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">EPOCHS</span><span 
class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">checkpoint</span><span class="p">],</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">load_weights</span><span class="p">(</span><span class="n">filepath</span><span class="p">)</span> <span class="n">_</span><span class="p">,</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">test_data</span><span class="p">,</span> <span class="n">test_labels</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Test accuracy: </span><span class="si">{</span><span class="nb">round</span><span class="p">(</span><span class="n">accuracy</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="mi">100</span><span class="p">,</span><span class="w"> </span><span class="mi">2</span><span class="p">)</span><span class="si">}</span><span class="s2">%"</span><span class="p">)</span> <span class="k">return</span> <span class="n">model</span> </code></pre></div> <hr /> <h2 id="model-training-and-inference">Model training and inference</h2> <div class="codehilite"><pre><span></span><code><span class="n">trained_model</span> <span class="o">=</span> <span class="n">run_experiment</span><span class="p">()</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/5 16/16 ━━━━━━━━━━━━━━━━━━━━ 0s 160ms/step - accuracy: 0.5286 - loss: 2.6762 Epoch 1: val_loss improved from inf to 7.75026, saving model to /tmp/video_classifier.weights.h5 16/16 ━━━━━━━━━━━━━━━━━━━━ 7s 272ms/step - accuracy: 0.5387 - loss: 2.6139 - val_accuracy: 0.0000e+00 - val_loss: 7.7503 Epoch 2/5 15/16 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.9396 - loss: 
0.2264 Epoch 2: val_loss improved from 7.75026 to 1.96635, saving model to /tmp/video_classifier.weights.h5 16/16 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step - accuracy: 0.9406 - loss: 0.2186 - val_accuracy: 0.4000 - val_loss: 1.9664 Epoch 3/5 14/16 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.9823 - loss: 0.0384 Epoch 3: val_loss did not improve from 1.96635 16/16 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9822 - loss: 0.0391 - val_accuracy: 0.3667 - val_loss: 3.7076 Epoch 4/5 15/16 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 0.9825 - loss: 0.0681 Epoch 4: val_loss did not improve from 1.96635 16/16 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - accuracy: 0.9831 - loss: 0.0674 - val_accuracy: 0.4222 - val_loss: 3.7957 Epoch 5/5 15/16 ━━━━━━━━━━━━━━━━━━━━ 0s 4ms/step - accuracy: 1.0000 - loss: 0.0035 Epoch 5: val_loss improved from 1.96635 to 1.56071, saving model to /tmp/video_classifier.weights.h5 16/16 ━━━━━━━━━━━━━━━━━━━━ 0s 15ms/step - accuracy: 1.0000 - loss: 0.0033 - val_accuracy: 0.6333 - val_loss: 1.5607 7/7 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - accuracy: 0.9286 - loss: 0.4434 Test accuracy: 89.29% </code></pre></div> </div> <p><strong>Note</strong>: This model has ~4.23 million parameters, far more than the sequence model (99,918 parameters) used in the prequel to this example. 
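</p> <p>As a rough check on that figure, the sketch below tallies the parameters of <code>get_compiled_model()</code> by hand. It assumes <code>NUM_FEATURES = 1024</code>, <code>MAX_SEQ_LENGTH = 20</code> and 5 classes (consistent with the shapes and vocabulary printed earlier), plus Keras' default <code>MultiHeadAttention</code> settings (value projections the same size as <code>key_dim</code>, biases enabled). It is illustrative arithmetic, not output produced by the example.</p> <div class="codehilite"><pre><span></span><code># Hand count of the trainable parameters in get_compiled_model(), assuming
# NUM_FEATURES = 1024, MAX_SEQ_LENGTH = 20, 5 classes, and Keras' default
# MultiHeadAttention settings (value_dim equal to key_dim, biases enabled).
embed_dim, seq_len, dense_dim, num_heads, classes = 1024, 20, 4, 1, 5
key_dim = embed_dim  # TransformerEncoder passes key_dim=embed_dim

# Query, key and value projections: kernel (embed_dim, heads, key_dim) + bias.
qkv = 3 * (embed_dim * num_heads * key_dim + num_heads * key_dim)
# Output projection back to embed_dim: kernel (heads, key_dim, embed_dim) + bias.
attn_out = num_heads * key_dim * embed_dim + embed_dim
attention = qkv + attn_out

# dense_proj: Dense(dense_dim) followed by Dense(embed_dim).
dense_proj = (embed_dim * dense_dim + dense_dim) + (dense_dim * embed_dim + embed_dim)
layer_norms = 2 * (2 * embed_dim)           # gamma and beta for each LayerNormalization
position_embedding = seq_len * embed_dim    # one row per frame position
classifier = embed_dim * classes + classes  # final softmax Dense layer

total = attention + dense_proj + layer_norms + position_embedding + classifier
print(f"attention: {attention:,}  total: {total:,}")
# attention: 4,198,400  total: 4,237,321
</code></pre></div> <p>Under these assumptions the total comes to about 4.24 million, consistent with the figure quoted above, and more than 99% of it sits in the four 1024 x 1024 attention projections that result from setting <code>key_dim=embed_dim</code>.</p> <p>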
This kind of Transformer model works best with a larger dataset and a longer pre-training schedule.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">prepare_single_video</span><span class="p">(</span><span class="n">frames</span><span class="p">):</span> <span class="n">frame_features</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">MAX_SEQ_LENGTH</span><span class="p">,</span> <span class="n">NUM_FEATURES</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span><span class="p">)</span> <span class="c1"># Pad shorter videos.</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">frames</span><span class="p">)</span> <span class="o"><</span> <span class="n">MAX_SEQ_LENGTH</span><span class="p">:</span> <span class="n">diff</span> <span class="o">=</span> <span class="n">MAX_SEQ_LENGTH</span> <span class="o">-</span> <span class="nb">len</span><span class="p">(</span><span class="n">frames</span><span class="p">)</span> <span class="n">padding</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">diff</span><span class="p">,</span> <span class="n">IMG_SIZE</span><span class="p">,</span> <span class="n">IMG_SIZE</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="n">frames</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">frames</span><span class="p">,</span> <span class="n">padding</span><span class="p">])</span> <span class="n">frames</span> <span 
class="o">=</span> <span class="n">frames</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span>

    <span class="c1"># Extract features from the frames of the current video.</span>
    <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">batch</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">frames</span><span class="p">):</span>
        <span class="n">video_length</span> <span class="o">=</span> <span class="n">batch</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
        <span class="n">length</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">MAX_SEQ_LENGTH</span><span class="p">,</span> <span class="n">video_length</span><span class="p">)</span>
        <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">length</span><span class="p">):</span>
            <span class="k">if</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">batch</span><span class="p">[</span><span class="n">j</span><span class="p">,</span> <span class="p">:])</span> <span class="o">></span> <span class="mf">0.0</span><span class="p">:</span>
                <span class="n">frame_features</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="n">feature_extractor</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">batch</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="p">:])</span>
            <span class="k">else</span><span class="p">:</span>
                <span class="n">frame_features</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">,</span> <span class="p">:]</span> <span class="o">=</span> <span class="mf">0.0</span>

    <span class="k">return</span> <span class="n">frame_features</span>


<span class="k">def</span> <span class="nf">predict_action</span><span class="p">(</span><span class="n">path</span><span class="p">):</span>
    <span class="n">class_vocab</span> <span class="o">=</span> <span class="n">label_processor</span><span class="o">.</span><span class="n">get_vocabulary</span><span class="p">()</span>

    <span class="n">frames</span> <span class="o">=</span> <span class="n">load_video</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="s2">"test"</span><span class="p">,</span> <span class="n">path</span><span class="p">),</span> <span class="n">offload_to_cpu</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
    <span class="n">frame_features</span> <span class="o">=</span> <span class="n">prepare_single_video</span><span class="p">(</span><span class="n">frames</span><span class="p">)</span>
    <span class="n">probabilities</span> <span class="o">=</span> <span class="n">trained_model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">frame_features</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>

    <span class="n">plot_x_axis</span><span class="p">,</span> <span class="n">plot_y_axis</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[]</span>

    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">np</span><span class="o">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">probabilities</span><span class="p">)[::</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span>
        <span class="n">plot_x_axis</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">class_vocab</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
        <span class="n">plot_y_axis</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">probabilities</span><span class="p">[</span><span class="n">i</span><span class="p">])</span>
        <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">" </span><span class="si">{</span><span class="n">class_vocab</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="si">}</span><span class="s2">: </span><span class="si">{</span><span class="n">probabilities</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="mi">100</span><span class="si">:</span><span class="s2">5.2f</span><span class="si">}</span><span class="s2">%"</span><span class="p">)</span>

    <span class="n">plt</span><span class="o">.</span><span class="n">bar</span><span class="p">(</span><span class="n">plot_x_axis</span><span class="p">,</span> <span class="n">plot_y_axis</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="n">plot_x_axis</span><span class="p">)</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s2">"class_label"</span><span class="p">)</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s2">"Probability"</span><span class="p">)</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>

    <span class="k">return</span> <span class="n">frames</span>


<span class="c1"># This utility is for visualization.</span>
<span class="c1"># Referenced from:</span>
<span class="c1"># https://www.tensorflow.org/hub/tutorials/action_recognition_with_tf_hub</span>
<span class="k">def</span> <span class="nf">to_gif</span><span class="p">(</span><span class="n">images</span><span class="p">):</span>
    <span class="n">converted_images</span> <span class="o">=</span> <span class="n">images</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span>
    <span class="n">imageio</span><span class="o">.</span><span class="n">mimsave</span><span class="p">(</span><span class="s2">"animation.gif"</span><span class="p">,</span> <span class="n">converted_images</span><span class="p">,</span> <span class="n">fps</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">embed</span><span class="o">.</span><span class="n">embed_file</span><span class="p">(</span><span class="s2">"animation.gif"</span><span class="p">)</span>


<span class="n">test_video</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">test_df</span><span class="p">[</span><span class="s2">"video_name"</span><span class="p">]</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">tolist</span><span class="p">())</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Test video path: </span><span class="si">{</span><span class="n">test_video</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
<span class="n">test_frames</span> <span class="o">=</span> <span class="n">predict_action</span><span class="p">(</span><span class="n">test_video</span><span class="p">)</span>
<span class="n">to_gif</span><span class="p">(</span><span class="n">test_frames</span><span class="p">[:</span><span class="n">MAX_SEQ_LENGTH</span><span class="p">])</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Test video path: v_ShavingBeard_g03_c02.avi
1/1 ━━━━━━━━━━━━━━━━━━━━ 20s 20s/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 11ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 12ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 557ms/step
 ShavingBeard: 100.00%
 Punch:  0.00%
 CricketShot:  0.00%
 TennisSwing:  0.00%
 PlayingCello:  0.00%
</code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/video_transformers/video_transformers_23_1.png" /></p> <p>The model's performance is far from optimal because it was trained on a small dataset.</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#video-classification-with-transformers'>Video Classification with Transformers</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#data-collection'>Data collection</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#define-hyperparameters'>Define hyperparameters</a> </div> <div
class='k-outline-depth-2'> ◆ <a href='#data-preparation'>Data preparation</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#building-the-transformerbased-model'>Building the Transformer-based model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#utility-functions-for-training'>Utility functions for training</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#model-training-and-inference'>Model training and inference</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>