CINXE.COM
Object detection with Vision Transformers
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/object_detection_using_vision_transformer/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Object detection with Vision Transformers"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Object detection with Vision Transformers"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Object detection with Vision Transformers</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a class="nav-sublink2" href="/examples/vision/mobilevit/">A mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using Composable Fully-Convolutional Networks</a> <a class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2 active" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved Robustness</a> <a class="nav-sublink2" href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/bit/">Image Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a class="nav-sublink2" href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/vision/'>Computer Vision</a> / Object detection with Vision Transformers </div> <div class='k-content'> <h1 id="object-detection-with-vision-transformers">Object detection with Vision Transformers</h1> <p><strong>Author:</strong> <a href="https://www.linkedin.com/in/karan-dave-811413164/">Karan V. Dave</a><br> <strong>Date created:</strong> 2022/03/27<br> <strong>Last modified:</strong> 2023/11/20<br> <strong>Description:</strong> A simple Keras implementation of object detection using Vision Transformers.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/object_detection_using_vision_transformer.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/object_detection_using_vision_transformer.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>The article <a href="https://arxiv.org/abs/2010.11929">Vision Transformer (ViT)</a> architecture by Alexey Dosovitskiy et al. demonstrates that a pure transformer applied directly to sequences of image patches can perform well on object detection tasks.</p> <p>In this Keras example, we implement an object detection ViT and we train it on the <a href="http://www.vision.caltech.edu/datasets/">Caltech 101 dataset</a> to detect an airplane in the given image.</p> <hr /> <h2 id="imports-and-setup">Imports and setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"jax"</span> <span class="c1"># @param ["tensorflow", "jax", "torch"]</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">ops</span> <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">cv2</span> <span class="kn">import</span> <span class="nn">os</span> <span class="kn">import</span> <span class="nn">scipy.io</span> <span class="kn">import</span> <span class="nn">shutil</span> </code></pre></div> <hr /> <h2 id="prepare-dataset">Prepare dataset</h2> <p>We use the <a href="https://data.caltech.edu/records/mzrjq-6wc02">Caltech 101 Dataset</a>.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Path to images and annotations</span> <span class="n">path_images</span> <span class="o">=</span> <span class="s2">"./101_ObjectCategories/airplanes/"</span> <span class="n">path_annot</span> <span class="o">=</span> <span class="s2">"./Annotations/Airplanes_Side_2/"</span> <span class="n">path_to_downloaded_file</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">get_file</span><span class="p">(</span> <span class="n">fname</span><span class="o">=</span><span class="s2">"caltech_101_zipped"</span><span class="p">,</span> <span class="n">origin</span><span class="o">=</span><span class="s2">"https://data.caltech.edu/records/mzrjq-6wc02/files/caltech-101.zip"</span><span class="p">,</span> <span class="n">extract</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">archive_format</span><span class="o">=</span><span class="s2">"zip"</span><span class="p">,</span> <span class="c1"># downloaded file format</span> <span class="n">cache_dir</span><span class="o">=</span><span class="s2">"/"</span><span class="p">,</span> <span class="c1"># cache and extract in current directory</span> <span class="p">)</span> <span class="n">download_base_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="n">path_to_downloaded_file</span><span class="p">)</span> <span class="c1"># Extracting tar files found inside main zip file</span> <span class="n">shutil</span><span class="o">.</span><span class="n">unpack_archive</span><span class="p">(</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">download_base_dir</span><span class="p">,</span> <span class="s2">"caltech-101"</span><span class="p">,</span> <span class="s2">"101_ObjectCategories.tar.gz"</span><span class="p">),</span> <span class="s2">"."</span> <span class="p">)</span> <span class="n">shutil</span><span class="o">.</span><span class="n">unpack_archive</span><span class="p">(</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">download_base_dir</span><span class="p">,</span> <span class="s2">"caltech-101"</span><span class="p">,</span> <span class="s2">"Annotations.tar"</span><span class="p">),</span> <span class="s2">"."</span> <span class="p">)</span> <span class="c1"># list of paths to images and annotations</span> <span class="n">image_paths</span> <span class="o">=</span> <span class="p">[</span> <span class="n">f</span> <span class="k">for</span> <span class="n">f</span> <span class="ow">in</span> <span class="n">os</span><span class="o">.</span><span class="n">listdir</span><span class="p">(</span><span class="n">path_images</span><span class="p">)</span> <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isfile</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path_images</span><span class="p">,</span> <span class="n">f</span><span class="p">))</span> <span class="p">]</span> <span class="n">annot_paths</span> <span class="o">=</span> <span class="p">[</span> <span class="n">f</span> <span class="k">for</span> <span class="n">f</span> <span class="ow">in</span> <span class="n">os</span><span class="o">.</span><span class="n">listdir</span><span class="p">(</span><span class="n">path_annot</span><span class="p">)</span> <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">isfile</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">path_annot</span><span class="p">,</span> <span class="n">f</span><span class="p">))</span> <span class="p">]</span> <span class="n">image_paths</span><span class="o">.</span><span class="n">sort</span><span class="p">()</span> <span class="n">annot_paths</span><span class="o">.</span><span class="n">sort</span><span class="p">()</span> <span class="n">image_size</span> <span class="o">=</span> <span class="mi">224</span> <span class="c1"># resize input images to this size</span> <span class="n">images</span><span class="p">,</span> <span class="n">targets</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[]</span> <span class="c1"># loop over the annotations and images, preprocess them and store in lists</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">annot_paths</span><span class="p">)):</span> <span class="c1"># Access bounding box coordinates</span> <span class="n">annot</span> <span class="o">=</span> <span class="n">scipy</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">loadmat</span><span class="p">(</span><span class="n">path_annot</span> <span class="o">+</span> <span class="n">annot_paths</span><span class="p">[</span><span class="n">i</span><span class="p">])[</span><span class="s2">"box_coord"</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="n">top_left_x</span><span class="p">,</span> <span class="n">top_left_y</span> <span class="o">=</span> <span class="n">annot</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">annot</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="n">bottom_right_x</span><span class="p">,</span> <span class="n">bottom_right_y</span> <span class="o">=</span> <span class="n">annot</span><span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="n">annot</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="n">image</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">load_img</span><span class="p">(</span> <span class="n">path_images</span> <span class="o">+</span> <span class="n">image_paths</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="p">)</span> <span class="p">(</span><span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">)</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">size</span><span class="p">[:</span><span class="mi">2</span><span class="p">]</span> <span class="c1"># resize images</span> <span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">((</span><span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">))</span> <span class="c1"># convert image to array and append to list</span> <span class="n">images</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">img_to_array</span><span class="p">(</span><span class="n">image</span><span class="p">))</span> <span class="c1"># apply relative scaling to bounding boxes as per given image and append to list</span> <span class="n">targets</span><span class="o">.</span><span class="n">append</span><span class="p">(</span> <span class="p">(</span> <span class="nb">float</span><span class="p">(</span><span class="n">top_left_x</span><span class="p">)</span> <span class="o">/</span> <span class="n">w</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="n">top_left_y</span><span class="p">)</span> <span class="o">/</span> <span class="n">h</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="n">bottom_right_x</span><span class="p">)</span> <span class="o">/</span> <span class="n">w</span><span class="p">,</span> <span class="nb">float</span><span class="p">(</span><span class="n">bottom_right_y</span><span class="p">)</span> <span class="o">/</span> <span class="n">h</span><span class="p">,</span> <span class="p">)</span> <span class="p">)</span> <span class="c1"># Convert the list to numpy array, split to train and test dataset</span> <span class="p">(</span><span class="n">x_train</span><span class="p">),</span> <span class="p">(</span><span class="n">y_train</span><span class="p">)</span> <span class="o">=</span> <span class="p">(</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">images</span><span class="p">[:</span> <span class="nb">int</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">images</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.8</span><span class="p">)]),</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">targets</span><span class="p">[:</span> <span class="nb">int</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">targets</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.8</span><span class="p">)]),</span> <span class="p">)</span> <span class="p">(</span><span class="n">x_test</span><span class="p">),</span> <span class="p">(</span><span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="p">(</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">images</span><span class="p">[</span><span class="nb">int</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">images</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.8</span><span class="p">)</span> <span class="p">:]),</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">targets</span><span class="p">[</span><span class="nb">int</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">targets</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.8</span><span class="p">)</span> <span class="p">:]),</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="implement-multilayerperceptron-mlp">Implement multilayer-perceptron (MLP)</h2> <p>We use the code from the Keras example <a href="https://keras.io/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> as a reference.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">mlp</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">hidden_units</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="p">):</span> <span class="k">for</span> <span class="n">units</span> <span class="ow">in</span> <span class="n">hidden_units</span><span class="p">:</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">activations</span><span class="o">.</span><span class="n">gelu</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_rate</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> </code></pre></div> <hr /> <h2 id="implement-the-patch-creation-layer">Implement the patch creation layer</h2> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">Patches</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">patch_size</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span> <span class="o">=</span> <span class="n">patch_size</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">images</span><span class="p">):</span> <span class="n">input_shape</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">images</span><span class="p">)</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="n">height</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="n">width</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="n">channels</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="n">num_patches_h</span> <span class="o">=</span> <span class="n">height</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span> <span class="n">num_patches_w</span> <span class="o">=</span> <span class="n">width</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span> <span class="n">patches</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">extract_patches</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span><span class="p">)</span> <span class="n">patches</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span> <span class="n">patches</span><span class="p">,</span> <span class="p">(</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">num_patches_h</span> <span class="o">*</span> <span class="n">num_patches_w</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span> <span class="o">*</span> <span class="n">channels</span><span class="p">,</span> <span class="p">),</span> <span class="p">)</span> <span class="k">return</span> <span class="n">patches</span> <span class="k">def</span> <span class="nf">get_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="n">config</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">get_config</span><span class="p">()</span> <span class="n">config</span><span class="o">.</span><span class="n">update</span><span class="p">({</span><span class="s2">"patch_size"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span><span class="p">})</span> <span class="k">return</span> <span class="n">config</span> </code></pre></div> <hr /> <h2 id="display-patches-for-an-input-image">Display patches for an input image</h2> <div class="codehilite"><pre><span></span><code><span class="n">patch_size</span> <span class="o">=</span> <span class="mi">32</span> <span class="c1"># Size of the patches to be extracted from the input images</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">x_train</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"uint8"</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="n">patches</span> <span class="o">=</span> <span class="n">Patches</span><span class="p">(</span><span class="n">patch_size</span><span class="p">)(</span><span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">x_train</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Image size: </span><span class="si">{</span><span class="n">image_size</span><span class="si">}</span><span class="s2"> X </span><span class="si">{</span><span class="n">image_size</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Patch size: </span><span class="si">{</span><span class="n">patch_size</span><span class="si">}</span><span class="s2"> X </span><span class="si">{</span><span class="n">patch_size</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">patches</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s2"> patches per image </span><span class="se">\n</span><span class="si">{</span><span class="n">patches</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="si">}</span><span class="s2"> elements per patch"</span><span class="p">)</span> <span class="n">n</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">patches</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">patch</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">patches</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="n">n</span><span class="p">,</span> <span class="n">n</span><span class="p">,</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="n">patch_img</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">patch</span><span class="p">,</span> <span class="p">(</span><span class="n">patch_size</span><span class="p">,</span> <span class="n">patch_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">convert_to_numpy</span><span class="p">(</span><span class="n">patch_img</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"uint8"</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Image size: 224 X 224 Patch size: 32 X 32 49 patches per image 3072 elements per patch </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/object_detection_using_vision_transformer/object_detection_using_vision_transformer_11_1.png" /></p> <p><img alt="png" src="/img/examples/vision/object_detection_using_vision_transformer/object_detection_using_vision_transformer_11_2.png" /></p> <hr /> <h2 id="implement-the-patch-encoding-layer">Implement the patch encoding layer</h2> <p>The <code>PatchEncoder</code> layer linearly transforms a patch by projecting it into a vector of size <code>projection_dim</code>. It also adds a learnable position embedding to the projected vector.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">PatchEncoder</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_patches</span><span class="p">,</span> <span class="n">projection_dim</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_patches</span> <span class="o">=</span> <span class="n">num_patches</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="n">projection_dim</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">position_embedding</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span> <span class="n">input_dim</span><span class="o">=</span><span class="n">num_patches</span><span class="p">,</span> <span class="n">output_dim</span><span class="o">=</span><span class="n">projection_dim</span> <span class="p">)</span> <span class="c1"># Override function to avoid error while saving model</span> <span class="k">def</span> <span class="nf">get_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="n">config</span> <span class="o">=</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">get_config</span><span class="p">()</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span> <span class="n">config</span><span class="o">.</span><span class="n">update</span><span class="p">(</span> <span class="p">{</span> <span class="s2">"input_shape"</span><span class="p">:</span> <span class="n">input_shape</span><span class="p">,</span> <span class="s2">"patch_size"</span><span class="p">:</span> <span class="n">patch_size</span><span class="p">,</span> <span class="s2">"num_patches"</span><span class="p">:</span> <span class="n">num_patches</span><span class="p">,</span> <span class="s2">"projection_dim"</span><span class="p">:</span> <span class="n">projection_dim</span><span class="p">,</span> <span class="s2">"num_heads"</span><span class="p">:</span> <span class="n">num_heads</span><span class="p">,</span> <span class="s2">"transformer_units"</span><span class="p">:</span> <span class="n">transformer_units</span><span class="p">,</span> <span class="s2">"transformer_layers"</span><span class="p">:</span> <span class="n">transformer_layers</span><span class="p">,</span> <span class="s2">"mlp_head_units"</span><span class="p">:</span> <span class="n">mlp_head_units</span><span class="p">,</span> <span class="p">}</span> <span class="p">)</span> <span class="k">return</span> <span class="n">config</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">patch</span><span class="p">):</span> <span class="n">positions</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span> <span class="n">ops</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">start</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">stop</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_patches</span><span class="p">,</span> <span class="n">step</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span> <span class="p">)</span> <span class="n">projected_patches</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection</span><span class="p">(</span><span class="n">patch</span><span class="p">)</span> <span class="n">encoded</span> <span class="o">=</span> <span class="n">projected_patches</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">position_embedding</span><span class="p">(</span><span class="n">positions</span><span class="p">)</span> <span class="k">return</span> <span class="n">encoded</span> </code></pre></div> <hr /> <h2 id="build-the-vit-model">Build the ViT model</h2> <p>The ViT model has multiple Transformer blocks. The <code>MultiHeadAttention</code> layer is used for self-attention, applied to the sequence of image patches. The encoded patches (skip connection) and self-attention layer outputs are normalized and fed into a multilayer perceptron (MLP). The model outputs four dimensions representing the bounding box coordinates of an object.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">create_vit_object_detector</span><span class="p">(</span> <span class="n">input_shape</span><span class="p">,</span> <span class="n">patch_size</span><span class="p">,</span> <span class="n">num_patches</span><span class="p">,</span> <span class="n">projection_dim</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">transformer_units</span><span class="p">,</span> <span class="n">transformer_layers</span><span class="p">,</span> <span class="n">mlp_head_units</span><span class="p">,</span> <span class="p">):</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">input_shape</span><span class="p">)</span> <span class="c1"># Create patches</span> <span class="n">patches</span> <span class="o">=</span> <span class="n">Patches</span><span class="p">(</span><span class="n">patch_size</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="c1"># Encode patches</span> <span class="n">encoded_patches</span> <span class="o">=</span> <span class="n">PatchEncoder</span><span class="p">(</span><span class="n">num_patches</span><span class="p">,</span> <span class="n">projection_dim</span><span class="p">)(</span><span class="n">patches</span><span class="p">)</span> <span class="c1"># Create multiple layers of the Transformer block.</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">transformer_layers</span><span class="p">):</span> <span class="c1"># Layer normalization 1.</span> <span class="n">x1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span><span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">)(</span><span class="n">encoded_patches</span><span class="p">)</span> <span class="c1"># Create a multi-head attention layer.</span> <span class="n">attention_output</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MultiHeadAttention</span><span class="p">(</span> <span class="n">num_heads</span><span class="o">=</span><span class="n">num_heads</span><span class="p">,</span> <span class="n">key_dim</span><span class="o">=</span><span class="n">projection_dim</span><span class="p">,</span> <span class="n">dropout</span><span class="o">=</span><span class="mf">0.1</span> <span class="p">)(</span><span class="n">x1</span><span class="p">,</span> <span class="n">x1</span><span class="p">)</span> <span class="c1"># Skip connection 1.</span> <span class="n">x2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Add</span><span class="p">()([</span><span class="n">attention_output</span><span class="p">,</span> <span class="n">encoded_patches</span><span class="p">])</span> <span class="c1"># Layer normalization 2.</span> <span class="n">x3</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span><span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">)(</span><span class="n">x2</span><span class="p">)</span> <span class="c1"># MLP</span> <span class="n">x3</span> <span class="o">=</span> <span class="n">mlp</span><span class="p">(</span><span class="n">x3</span><span class="p">,</span> <span class="n">hidden_units</span><span class="o">=</span><span class="n">transformer_units</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="o">=</span><span class="mf">0.1</span><span class="p">)</span> <span class="c1"># Skip connection 2.</span> <span class="n">encoded_patches</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Add</span><span class="p">()([</span><span class="n">x3</span><span class="p">,</span> <span class="n">x2</span><span class="p">])</span> <span class="c1"># Create a [batch_size, projection_dim] tensor.</span> <span class="n">representation</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span><span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">)(</span><span class="n">encoded_patches</span><span class="p">)</span> <span class="n">representation</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Flatten</span><span class="p">()(</span><span class="n">representation</span><span class="p">)</span> <span class="n">representation</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.3</span><span class="p">)(</span><span class="n">representation</span><span class="p">)</span> <span class="c1"># Add MLP.</span> <span class="n">features</span> <span class="o">=</span> <span class="n">mlp</span><span class="p">(</span><span class="n">representation</span><span class="p">,</span> <span class="n">hidden_units</span><span class="o">=</span><span class="n">mlp_head_units</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="o">=</span><span class="mf">0.3</span><span class="p">)</span> <span class="n">bounding_box</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">4</span><span class="p">)(</span> <span class="n">features</span> <span class="p">)</span> <span class="c1"># Final four neurons that output bounding box</span> <span class="c1"># return Keras model.</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">bounding_box</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="run-the-experiment">Run the experiment</h2> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">run_experiment</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">learning_rate</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">num_epochs</span><span class="p">):</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">AdamW</span><span class="p">(</span> <span class="n">learning_rate</span><span class="o">=</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">weight_decay</span> <span class="p">)</span> <span class="c1"># Compile model.</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">MeanSquaredError</span><span class="p">())</span> <span class="n">checkpoint_filepath</span> <span class="o">=</span> <span class="s2">"vit_object_detector.weights.h5"</span> <span class="n">checkpoint_callback</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">ModelCheckpoint</span><span class="p">(</span> <span class="n">checkpoint_filepath</span><span class="p">,</span> <span class="n">monitor</span><span class="o">=</span><span class="s2">"val_loss"</span><span class="p">,</span> <span class="n">save_best_only</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">save_weights_only</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="n">history</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">x</span><span class="o">=</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="n">y_train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">num_epochs</span><span class="p">,</span> <span class="n">validation_split</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span> <span class="n">checkpoint_callback</span><span class="p">,</span> <span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">EarlyStopping</span><span class="p">(</span><span class="n">monitor</span><span class="o">=</span><span class="s2">"val_loss"</span><span class="p">,</span> <span class="n">patience</span><span class="o">=</span><span class="mi">10</span><span class="p">),</span> <span class="p">],</span> <span class="p">)</span> <span class="k">return</span> <span class="n">history</span> <span class="n">input_shape</span> <span class="o">=</span> <span class="p">(</span><span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="c1"># input image shape</span> <span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">0.001</span> <span class="n">weight_decay</span> <span class="o">=</span> <span class="mf">0.0001</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">32</span> <span class="n">num_epochs</span> <span class="o">=</span> <span class="mi">100</span> <span class="n">num_patches</span> <span class="o">=</span> <span class="p">(</span><span class="n">image_size</span> <span class="o">//</span> <span class="n">patch_size</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="n">projection_dim</span> <span class="o">=</span> <span class="mi">64</span> <span class="n">num_heads</span> <span class="o">=</span> <span class="mi">4</span> <span class="c1"># Size of the transformer layers</span> <span class="n">transformer_units</span> <span class="o">=</span> <span class="p">[</span> <span class="n">projection_dim</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="n">projection_dim</span><span class="p">,</span> <span class="p">]</span> <span class="n">transformer_layers</span> <span class="o">=</span> <span class="mi">4</span> <span class="n">mlp_head_units</span> <span class="o">=</span> <span class="p">[</span><span class="mi">2048</span><span class="p">,</span> <span class="mi">1024</span><span class="p">,</span> <span class="mi">512</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">32</span><span class="p">]</span> <span class="c1"># Size of the dense layers</span> <span class="n">history</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">num_patches</span> <span class="o">=</span> <span class="p">(</span><span class="n">image_size</span> <span class="o">//</span> <span class="n">patch_size</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="n">vit_object_detector</span> <span class="o">=</span> <span class="n">create_vit_object_detector</span><span class="p">(</span> <span class="n">input_shape</span><span class="p">,</span> <span class="n">patch_size</span><span class="p">,</span> <span class="n">num_patches</span><span class="p">,</span> <span class="n">projection_dim</span><span class="p">,</span> <span class="n">num_heads</span><span class="p">,</span> <span class="n">transformer_units</span><span class="p">,</span> <span class="n">transformer_layers</span><span class="p">,</span> <span class="n">mlp_head_units</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Train model</span> <span class="n">history</span> <span class="o">=</span> <span class="n">run_experiment</span><span class="p">(</span> <span class="n">vit_object_detector</span><span class="p">,</span> <span class="n">learning_rate</span><span class="p">,</span> <span class="n">weight_decay</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">,</span> <span class="n">num_epochs</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">plot_history</span><span class="p">(</span><span class="n">item</span><span class="p">):</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="n">item</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="n">item</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">"val_"</span> <span class="o">+</span> <span class="n">item</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s2">"val_"</span> <span class="o">+</span> <span class="n">item</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s2">"Epochs"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="n">item</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">"Train and Validation </span><span class="si">{}</span><span class="s2"> Over Epochs"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">item</span><span class="p">),</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">grid</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="n">plot_history</span><span class="p">(</span><span class="s2">"loss"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 9s 109ms/step - loss: 1.2097 - val_loss: 0.3468 Epoch 2/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/step - loss: 0.4260 - val_loss: 0.3102 Epoch 3/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/step - loss: 0.3268 - val_loss: 0.2727 Epoch 4/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/step - loss: 0.2815 - val_loss: 0.2391 Epoch 5/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/step - loss: 0.2290 - val_loss: 0.1735 Epoch 6/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 24ms/step - loss: 0.1870 - val_loss: 0.1055 Epoch 7/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/step - loss: 0.1401 - val_loss: 0.0610 Epoch 8/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/step - loss: 0.1122 - val_loss: 0.0274 Epoch 9/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0924 - val_loss: 0.0296 Epoch 10/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 24ms/step - loss: 0.0765 - val_loss: 0.0139 Epoch 11/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/step - loss: 0.0597 - val_loss: 0.0111 Epoch 12/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/step - loss: 0.0540 - val_loss: 0.0101 Epoch 13/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 24ms/step - loss: 0.0432 - val_loss: 0.0053 Epoch 14/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 24ms/step - loss: 0.0380 - val_loss: 0.0052 Epoch 15/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/step - loss: 0.0334 - val_loss: 0.0030 Epoch 16/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/step - loss: 0.0283 - val_loss: 0.0021 Epoch 17/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 24ms/step - loss: 0.0228 - val_loss: 0.0012 Epoch 18/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0244 - val_loss: 0.0017 Epoch 19/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0195 - val_loss: 0.0016 Epoch 20/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0189 - val_loss: 0.0020 Epoch 21/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0191 - val_loss: 0.0019 Epoch 22/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0174 - val_loss: 0.0016 Epoch 23/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0157 - val_loss: 0.0020 Epoch 24/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0157 - val_loss: 0.0015 Epoch 25/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0139 - val_loss: 0.0023 Epoch 26/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0130 - val_loss: 0.0017 Epoch 27/100 18/18 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - loss: 0.0157 - val_loss: 0.0014 </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/object_detection_using_vision_transformer/object_detection_using_vision_transformer_17_1.png" /></p> <hr /> <h2 id="evaluate-the-model">Evaluate the model</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">matplotlib.patches</span> <span class="k">as</span> <span class="nn">patches</span> <span class="c1"># Saves the model in current path</span> <span class="n">vit_object_detector</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s2">"vit_object_detector.keras"</span><span class="p">)</span> <span class="c1"># To calculate IoU (intersection over union, given two bounding boxes)</span> <span class="k">def</span> <span class="nf">bounding_box_intersection_over_union</span><span class="p">(</span><span class="n">box_predicted</span><span class="p">,</span> <span class="n">box_truth</span><span class="p">):</span> <span class="c1"># get (x, y) coordinates of intersection of bounding boxes</span> <span class="n">top_x_intersect</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">box_predicted</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">box_truth</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="n">top_y_intersect</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="n">box_predicted</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">box_truth</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span> <span class="n">bottom_x_intersect</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">box_predicted</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="n">box_truth</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span> <span class="n">bottom_y_intersect</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">box_predicted</span><span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="n">box_truth</span><span class="p">[</span><span class="mi">3</span><span class="p">])</span> <span class="c1"># calculate area of the intersection bb (bounding box)</span> <span class="n">intersection_area</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">bottom_x_intersect</span> <span class="o">-</span> <span class="n">top_x_intersect</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="nb">max</span><span class="p">(</span> <span class="mi">0</span><span class="p">,</span> <span class="n">bottom_y_intersect</span> <span class="o">-</span> <span class="n">top_y_intersect</span> <span class="o">+</span> <span class="mi">1</span> <span class="p">)</span> <span class="c1"># calculate area of the prediction bb and ground-truth bb</span> <span class="n">box_predicted_area</span> <span class="o">=</span> <span class="p">(</span><span class="n">box_predicted</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">box_predicted</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span> <span class="n">box_predicted</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">-</span> <span class="n">box_predicted</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span> <span class="p">)</span> <span class="n">box_truth_area</span> <span class="o">=</span> <span class="p">(</span><span class="n">box_truth</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">box_truth</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span> <span class="n">box_truth</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">-</span> <span class="n">box_truth</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="mi">1</span> <span class="p">)</span> <span class="c1"># calculate intersection over union by taking intersection</span> <span class="c1"># area and dividing it by the sum of predicted bb and ground truth</span> <span class="c1"># bb areas subtracted by the interesection area</span> <span class="c1"># return ioU</span> <span class="k">return</span> <span class="n">intersection_area</span> <span class="o">/</span> <span class="nb">float</span><span class="p">(</span> <span class="n">box_predicted_area</span> <span class="o">+</span> <span class="n">box_truth_area</span> <span class="o">-</span> <span class="n">intersection_area</span> <span class="p">)</span> <span class="n">i</span><span class="p">,</span> <span class="n">mean_iou</span> <span class="o">=</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span> <span class="c1"># Compare results for 10 images in the test set</span> <span class="k">for</span> <span class="n">input_image</span> <span class="ow">in</span> <span class="n">x_test</span><span class="p">[:</span><span class="mi">10</span><span class="p">]:</span> <span class="n">fig</span><span class="p">,</span> <span class="p">(</span><span class="n">ax1</span><span class="p">,</span> <span class="n">ax2</span><span class="p">)</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">15</span><span class="p">,</span> <span class="mi">15</span><span class="p">))</span> <span class="n">im</span> <span class="o">=</span> <span class="n">input_image</span> <span class="c1"># Display the image</span> <span class="n">ax1</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">im</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"uint8"</span><span class="p">))</span> <span class="n">ax2</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">im</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"uint8"</span><span class="p">))</span> <span class="n">input_image</span> <span class="o">=</span> <span class="n">cv2</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span> <span class="n">input_image</span><span class="p">,</span> <span class="p">(</span><span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">),</span> <span class="n">interpolation</span><span class="o">=</span><span class="n">cv2</span><span class="o">.</span><span class="n">INTER_AREA</span> <span class="p">)</span> <span class="n">input_image</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">input_image</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">preds</span> <span class="o">=</span> <span class="n">vit_object_detector</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">input_image</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">w</span><span class="p">)</span> <span class="o">=</span> <span class="p">(</span><span class="n">im</span><span class="p">)</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="n">top_left_x</span><span class="p">,</span> <span class="n">top_left_y</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">preds</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">w</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">preds</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">h</span><span class="p">)</span> <span class="n">bottom_right_x</span><span class="p">,</span> <span class="n">bottom_right_y</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">preds</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="n">w</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">preds</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">*</span> <span class="n">h</span><span class="p">)</span> <span class="n">box_predicted</span> <span class="o">=</span> <span class="p">[</span><span class="n">top_left_x</span><span class="p">,</span> <span class="n">top_left_y</span><span class="p">,</span> <span class="n">bottom_right_x</span><span class="p">,</span> <span class="n">bottom_right_y</span><span class="p">]</span> <span class="c1"># Create the bounding box</span> <span class="n">rect</span> <span class="o">=</span> <span class="n">patches</span><span class="o">.</span><span class="n">Rectangle</span><span class="p">(</span> <span class="p">(</span><span class="n">top_left_x</span><span class="p">,</span> <span class="n">top_left_y</span><span class="p">),</span> <span class="n">bottom_right_x</span> <span class="o">-</span> <span class="n">top_left_x</span><span class="p">,</span> <span class="n">bottom_right_y</span> <span class="o">-</span> <span class="n">top_left_y</span><span class="p">,</span> <span class="n">facecolor</span><span class="o">=</span><span class="s2">"none"</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s2">"red"</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Add the bounding box to the image</span> <span class="n">ax1</span><span class="o">.</span><span class="n">add_patch</span><span class="p">(</span><span class="n">rect</span><span class="p">)</span> <span class="n">ax1</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span> <span class="s2">"Predicted: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">top_left_x</span><span class="p">)</span> <span class="o">+</span> <span class="s2">", "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">top_left_y</span><span class="p">)</span> <span class="o">+</span> <span class="s2">", "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">bottom_right_x</span><span class="p">)</span> <span class="o">+</span> <span class="s2">", "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">bottom_right_y</span><span class="p">)</span> <span class="p">)</span> <span class="n">top_left_x</span><span class="p">,</span> <span class="n">top_left_y</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">y_test</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">w</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">y_test</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">h</span><span class="p">)</span> <span class="n">bottom_right_x</span><span class="p">,</span> <span class="n">bottom_right_y</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">y_test</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">2</span><span class="p">]</span> <span class="o">*</span> <span class="n">w</span><span class="p">),</span> <span class="nb">int</span><span class="p">(</span><span class="n">y_test</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="mi">3</span><span class="p">]</span> <span class="o">*</span> <span class="n">h</span><span class="p">)</span> <span class="n">box_truth</span> <span class="o">=</span> <span class="n">top_left_x</span><span class="p">,</span> <span class="n">top_left_y</span><span class="p">,</span> <span class="n">bottom_right_x</span><span class="p">,</span> <span class="n">bottom_right_y</span> <span class="n">mean_iou</span> <span class="o">+=</span> <span class="n">bounding_box_intersection_over_union</span><span class="p">(</span><span class="n">box_predicted</span><span class="p">,</span> <span class="n">box_truth</span><span class="p">)</span> <span class="c1"># Create the bounding box</span> <span class="n">rect</span> <span class="o">=</span> <span class="n">patches</span><span class="o">.</span><span class="n">Rectangle</span><span class="p">(</span> <span class="p">(</span><span class="n">top_left_x</span><span class="p">,</span> <span class="n">top_left_y</span><span class="p">),</span> <span class="n">bottom_right_x</span> <span class="o">-</span> <span class="n">top_left_x</span><span class="p">,</span> <span class="n">bottom_right_y</span> <span class="o">-</span> <span class="n">top_left_y</span><span class="p">,</span> <span class="n">facecolor</span><span class="o">=</span><span class="s2">"none"</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s2">"red"</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Add the bounding box to the image</span> <span class="n">ax2</span><span class="o">.</span><span class="n">add_patch</span><span class="p">(</span><span class="n">rect</span><span class="p">)</span> <span class="n">ax2</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span> <span class="s2">"Target: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">top_left_x</span><span class="p">)</span> <span class="o">+</span> <span class="s2">", "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">top_left_y</span><span class="p">)</span> <span class="o">+</span> <span class="s2">", "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">bottom_right_x</span><span class="p">)</span> <span class="o">+</span> <span class="s2">", "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">bottom_right_y</span><span class="p">)</span> <span class="o">+</span> <span class="s2">"</span><span class="se">\n</span><span class="s2">"</span> <span class="o">+</span> <span class="s2">"IoU"</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">bounding_box_intersection_over_union</span><span class="p">(</span><span class="n">box_predicted</span><span class="p">,</span> <span class="n">box_truth</span><span class="p">))</span> <span class="p">)</span> <span class="n">i</span> <span class="o">=</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"mean_iou: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">mean_iou</span> <span class="o">/</span> <span class="nb">len</span><span class="p">(</span><span class="n">x_test</span><span class="p">[:</span><span class="mi">10</span><span class="p">])))</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step mean_iou: 0.9092338486331416 </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/object_detection_using_vision_transformer/object_detection_using_vision_transformer_19_1.png" /></p> <p><img alt="png" src="/img/examples/vision/object_detection_using_vision_transformer/object_detection_using_vision_transformer_19_2.png" /></p> <p><img alt="png" src="/img/examples/vision/object_detection_using_vision_transformer/object_detection_using_vision_transformer_19_3.png" /></p> <p><img alt="png" src="/img/examples/vision/object_detection_using_vision_transformer/object_detection_using_vision_transformer_19_4.png" /></p> <p><img alt="png" src="/img/examples/vision/object_detection_using_vision_transformer/object_detection_using_vision_transformer_19_5.png" /></p> <p><img alt="png" src="/img/examples/vision/object_detection_using_vision_transformer/object_detection_using_vision_transformer_19_6.png" /></p> <p><img alt="png" src="/img/examples/vision/object_detection_using_vision_transformer/object_detection_using_vision_transformer_19_7.png" /></p> <p><img alt="png" src="/img/examples/vision/object_detection_using_vision_transformer/object_detection_using_vision_transformer_19_8.png" /></p> <p><img alt="png" src="/img/examples/vision/object_detection_using_vision_transformer/object_detection_using_vision_transformer_19_9.png" /></p> <p><img alt="png" src="/img/examples/vision/object_detection_using_vision_transformer/object_detection_using_vision_transformer_19_10.png" /></p> <p>This example demonstrates that a pure Transformer can be trained to predict the bounding boxes of an object in a given image, thus extending the use of Transformers to object detection tasks. The model can be improved further by tuning hyper-parameters and pre-training.</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#object-detection-with-vision-transformers'>Object detection with Vision Transformers</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#imports-and-setup'>Imports and setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#prepare-dataset'>Prepare dataset</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#implement-multilayerperceptron-mlp'>Implement multilayer-perceptron (MLP)</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#implement-the-patch-creation-layer'>Implement the patch creation layer</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#display-patches-for-an-input-image'>Display patches for an input image</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#implement-the-patch-encoding-layer'>Implement the patch encoding layer</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#build-the-vit-model'>Build the ViT model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#run-the-experiment'>Run the experiment</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#evaluate-the-model'>Evaluate the model</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>