CINXE.COM
Keypoint Detection with Transfer Learning
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/keypoint_detection/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Keypoint Detection with Transfer Learning"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Keypoint Detection with Transfer Learning"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Keypoint Detection with Transfer Learning</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a class="nav-sublink2" href="/examples/vision/mobilevit/">A mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using Composable Fully-Convolutional Networks</a> <a class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a class="nav-sublink2 active" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved Robustness</a> <a class="nav-sublink2" href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/bit/">Image Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a class="nav-sublink2" href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/vision/'>Computer Vision</a> / Keypoint Detection with Transfer Learning </div> <div class='k-content'> <h1 id="keypoint-detection-with-transfer-learning">Keypoint Detection with Transfer Learning</h1> <p><strong>Author:</strong> <a href="https://twitter.com/RisingSayak">Sayak Paul</a>, converted to Keras 3 by <a href="https://anasrz.com">Muhammad Anas Raza</a><br> <strong>Date created:</strong> 2021/05/02<br> <strong>Last modified:</strong> 2023/07/19<br> <strong>Description:</strong> Training a keypoint detector with data augmentation and transfer learning.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/keypoint_detection.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/keypoint_detection.py"><strong>GitHub source</strong></a></p> <p>Keypoint detection consists of locating key object parts. For example, the key parts of our faces include nose tips, eyebrows, eye corners, and so on. These parts help to represent the underlying object in a feature-rich manner. Keypoint detection has applications that include pose estimation, face detection, etc.</p> <p>In this example, we will build a keypoint detector using the <a href="https://github.com/benjiebob/StanfordExtra">StanfordExtra dataset</a>, using transfer learning. This example requires TensorFlow 2.4 or higher, as well as <a href="https://imgaug.readthedocs.io/"><code>imgaug</code></a> library, which can be installed using the following command:</p> <div class="codehilite"><pre><span></span><code><span class="err">!</span><span class="n">pip</span> <span class="n">install</span> <span class="o">-</span><span class="n">q</span> <span class="o">-</span><span class="n">U</span> <span class="n">imgaug</span> </code></pre></div> <hr /> <h2 id="data-collection">Data collection</h2> <p>The StanfordExtra dataset contains 12,000 images of dogs together with keypoints and segmentation maps. It is developed from the <a href="http://vision.stanford.edu/aditya86/ImageNetDogs/">Stanford dogs dataset</a>. It can be downloaded with the command below:</p> <div class="codehilite"><pre><span></span><code><span class="err">!</span><span class="n">wget</span> <span class="o">-</span><span class="n">q</span> <span class="n">http</span><span class="p">:</span><span class="o">//</span><span class="n">vision</span><span class="o">.</span><span class="n">stanford</span><span class="o">.</span><span class="n">edu</span><span class="o">/</span><span class="n">aditya86</span><span class="o">/</span><span class="n">ImageNetDogs</span><span class="o">/</span><span class="n">images</span><span class="o">.</span><span class="n">tar</span> </code></pre></div> <p>Annotations are provided as a single JSON file in the StanfordExtra dataset and one needs to fill <a href="https://forms.gle/sRtbicgxsWvRtRmUA">this form</a> to get access to it. The authors explicitly instruct users not to share the JSON file, and this example respects this wish: you should obtain the JSON file yourself.</p> <p>The JSON file is expected to be locally available as <code>stanfordextra_v12.zip</code>.</p> <p>After the files are downloaded, we can extract the archives.</p> <div class="codehilite"><pre><span></span><code><span class="err">!</span><span class="n">tar</span> <span class="n">xf</span> <span class="n">images</span><span class="o">.</span><span class="n">tar</span> <span class="err">!</span><span class="n">unzip</span> <span class="o">-</span><span class="n">qq</span> <span class="o">~/</span><span class="n">stanfordextra_v12</span><span class="o">.</span><span class="n">zip</span> </code></pre></div> <hr /> <h2 id="imports">Imports</h2> <div class="codehilite"><pre><span></span><code><span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="kn">from</span> <span class="nn">imgaug.augmentables.kps</span> <span class="kn">import</span> <span class="n">KeypointsOnImage</span> <span class="kn">from</span> <span class="nn">imgaug.augmentables.kps</span> <span class="kn">import</span> <span class="n">Keypoint</span> <span class="kn">import</span> <span class="nn">imgaug.augmenters</span> <span class="k">as</span> <span class="nn">iaa</span> <span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span> <span class="kn">from</span> <span class="nn">sklearn.model_selection</span> <span class="kn">import</span> <span class="n">train_test_split</span> <span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">pyplot</span> <span class="k">as</span> <span class="n">plt</span> <span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">json</span> <span class="kn">import</span> <span class="nn">os</span> </code></pre></div> <hr /> <h2 id="define-hyperparameters">Define hyperparameters</h2> <div class="codehilite"><pre><span></span><code><span class="n">IMG_SIZE</span> <span class="o">=</span> <span class="mi">224</span> <span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">64</span> <span class="n">EPOCHS</span> <span class="o">=</span> <span class="mi">5</span> <span class="n">NUM_KEYPOINTS</span> <span class="o">=</span> <span class="mi">24</span> <span class="o">*</span> <span class="mi">2</span> <span class="c1"># 24 pairs each having x and y coordinates</span> </code></pre></div> <hr /> <h2 id="load-data">Load data</h2> <p>The authors also provide a metadata file that specifies additional information about the keypoints, like color information, animal pose name, etc. We will load this file in a <code>pandas</code> dataframe to extract information for visualization purposes.</p> <div class="codehilite"><pre><span></span><code><span class="n">IMG_DIR</span> <span class="o">=</span> <span class="s2">"Images"</span> <span class="n">JSON</span> <span class="o">=</span> <span class="s2">"StanfordExtra_V12/StanfordExtra_v12.json"</span> <span class="n">KEYPOINT_DEF</span> <span class="o">=</span> <span class="p">(</span> <span class="s2">"https://github.com/benjiebob/StanfordExtra/raw/master/keypoint_definitions.csv"</span> <span class="p">)</span> <span class="c1"># Load the ground-truth annotations.</span> <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="n">JSON</span><span class="p">)</span> <span class="k">as</span> <span class="n">infile</span><span class="p">:</span> <span class="n">json_data</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">infile</span><span class="p">)</span> <span class="c1"># Set up a dictionary, mapping all the ground-truth information</span> <span class="c1"># with respect to the path of the image.</span> <span class="n">json_dict</span> <span class="o">=</span> <span class="p">{</span><span class="n">i</span><span class="p">[</span><span class="s2">"img_path"</span><span class="p">]:</span> <span class="n">i</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">json_data</span><span class="p">}</span> </code></pre></div> <p>A single entry of <code>json_dict</code> looks like the following:</p> <div class="codehilite"><pre><span></span><code>'n02085782-Japanese_spaniel/n02085782_2886.jpg': {'img_bbox': [205, 20, 116, 201], 'img_height': 272, 'img_path': 'n02085782-Japanese_spaniel/n02085782_2886.jpg', 'img_width': 350, 'is_multiple_dogs': False, 'joints': [[108.66666666666667, 252.0, 1], [147.66666666666666, 229.0, 1], [163.5, 208.5, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [54.0, 244.0, 1], [77.33333333333333, 225.33333333333334, 1], [79.0, 196.5, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [150.66666666666666, 86.66666666666667, 1], [88.66666666666667, 73.0, 1], [116.0, 106.33333333333333, 1], [109.0, 123.33333333333333, 1], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0], [0, 0, 0]], 'seg': ...} </code></pre></div> <p>In this example, the keys we are interested in are:</p> <ul> <li><code>img_path</code></li> <li><code>joints</code></li> </ul> <p>There are a total of 24 entries present inside <code>joints</code>. Each entry has 3 values:</p> <ul> <li>x-coordinate</li> <li>y-coordinate</li> <li>visibility flag of the keypoints (1 indicates visibility and 0 indicates non-visibility)</li> </ul> <p>As we can see <code>joints</code> contain multiple <code>[0, 0, 0]</code> entries which denote that those keypoints were not labeled. In this example, we will consider both non-visible as well as unlabeled keypoints in order to allow mini-batch learning.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Load the metdata definition file and preview it.</span> <span class="n">keypoint_def</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_csv</span><span class="p">(</span><span class="n">KEYPOINT_DEF</span><span class="p">)</span> <span class="n">keypoint_def</span><span class="o">.</span><span class="n">head</span><span class="p">()</span> <span class="c1"># Extract the colours and labels.</span> <span class="n">colours</span> <span class="o">=</span> <span class="n">keypoint_def</span><span class="p">[</span><span class="s2">"Hex colour"</span><span class="p">]</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span> <span class="n">colours</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"#"</span> <span class="o">+</span> <span class="n">colour</span> <span class="k">for</span> <span class="n">colour</span> <span class="ow">in</span> <span class="n">colours</span><span class="p">]</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">keypoint_def</span><span class="p">[</span><span class="s2">"Name"</span><span class="p">]</span><span class="o">.</span><span class="n">values</span><span class="o">.</span><span class="n">tolist</span><span class="p">()</span> <span class="c1"># Utility for reading an image and for getting its annotations.</span> <span class="k">def</span> <span class="nf">get_dog</span><span class="p">(</span><span class="n">name</span><span class="p">):</span> <span class="n">data</span> <span class="o">=</span> <span class="n">json_dict</span><span class="p">[</span><span class="n">name</span><span class="p">]</span> <span class="n">img_data</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">imread</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">IMG_DIR</span><span class="p">,</span> <span class="n">data</span><span class="p">[</span><span class="s2">"img_path"</span><span class="p">]))</span> <span class="c1"># If the image is RGBA convert it to RGB.</span> <span class="k">if</span> <span class="n">img_data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">==</span> <span class="mi">4</span><span class="p">:</span> <span class="n">img_data</span> <span class="o">=</span> <span class="n">img_data</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span> <span class="n">img_data</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">fromarray</span><span class="p">(</span><span class="n">img_data</span><span class="p">)</span> <span class="n">img_data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">img_data</span><span class="o">.</span><span class="n">convert</span><span class="p">(</span><span class="s2">"RGB"</span><span class="p">))</span> <span class="n">data</span><span class="p">[</span><span class="s2">"img_data"</span><span class="p">]</span> <span class="o">=</span> <span class="n">img_data</span> <span class="k">return</span> <span class="n">data</span> </code></pre></div> <hr /> <h2 id="visualize-data">Visualize data</h2> <p>Now, we write a utility function to visualize the images and their keypoints.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Parts of this code come from here:</span> <span class="c1"># https://github.com/benjiebob/StanfordExtra/blob/master/demo.ipynb</span> <span class="k">def</span> <span class="nf">visualize_keypoints</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">keypoints</span><span class="p">):</span> <span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">nrows</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">images</span><span class="p">),</span> <span class="n">ncols</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">12</span><span class="p">))</span> <span class="p">[</span><span class="n">ax</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="k">for</span> <span class="n">ax</span> <span class="ow">in</span> <span class="n">np</span><span class="o">.</span><span class="n">ravel</span><span class="p">(</span><span class="n">axes</span><span class="p">)]</span> <span class="k">for</span> <span class="p">(</span><span class="n">ax_orig</span><span class="p">,</span> <span class="n">ax_all</span><span class="p">),</span> <span class="n">image</span><span class="p">,</span> <span class="n">current_keypoint</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">axes</span><span class="p">,</span> <span class="n">images</span><span class="p">,</span> <span class="n">keypoints</span><span class="p">):</span> <span class="n">ax_orig</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="n">ax_all</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="c1"># If the keypoints were formed by `imgaug` then the coordinates need</span> <span class="c1"># to be iterated differently.</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">current_keypoint</span><span class="p">,</span> <span class="n">KeypointsOnImage</span><span class="p">):</span> <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">kp</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">current_keypoint</span><span class="o">.</span><span class="n">keypoints</span><span class="p">):</span> <span class="n">ax_all</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span> <span class="p">[</span><span class="n">kp</span><span class="o">.</span><span class="n">x</span><span class="p">],</span> <span class="p">[</span><span class="n">kp</span><span class="o">.</span><span class="n">y</span><span class="p">],</span> <span class="n">c</span><span class="o">=</span><span class="n">colours</span><span class="p">[</span><span class="n">idx</span><span class="p">],</span> <span class="n">marker</span><span class="o">=</span><span class="s2">"x"</span><span class="p">,</span> <span class="n">s</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">linewidths</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">current_keypoint</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">current_keypoint</span><span class="p">)</span> <span class="c1"># Since the last entry is the visibility flag, we discard it.</span> <span class="n">current_keypoint</span> <span class="o">=</span> <span class="n">current_keypoint</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">current_keypoint</span><span class="p">):</span> <span class="n">ax_all</span><span class="o">.</span><span class="n">scatter</span><span class="p">([</span><span class="n">x</span><span class="p">],</span> <span class="p">[</span><span class="n">y</span><span class="p">],</span> <span class="n">c</span><span class="o">=</span><span class="n">colours</span><span class="p">[</span><span class="n">idx</span><span class="p">],</span> <span class="n">marker</span><span class="o">=</span><span class="s2">"x"</span><span class="p">,</span> <span class="n">s</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">linewidths</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">tight_layout</span><span class="p">(</span><span class="n">pad</span><span class="o">=</span><span class="mf">2.0</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="c1"># Select four samples randomly for visualization.</span> <span class="n">samples</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">json_dict</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="n">num_samples</span> <span class="o">=</span> <span class="mi">4</span> <span class="n">selected_samples</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">samples</span><span class="p">,</span> <span class="n">num_samples</span><span class="p">,</span> <span class="n">replace</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="n">images</span><span class="p">,</span> <span class="n">keypoints</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">sample</span> <span class="ow">in</span> <span class="n">selected_samples</span><span class="p">:</span> <span class="n">data</span> <span class="o">=</span> <span class="n">get_dog</span><span class="p">(</span><span class="n">sample</span><span class="p">)</span> <span class="n">image</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="s2">"img_data"</span><span class="p">]</span> <span class="n">keypoint</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="s2">"joints"</span><span class="p">]</span> <span class="n">images</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="n">keypoints</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">keypoint</span><span class="p">)</span> <span class="n">visualize_keypoints</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">keypoints</span><span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/keypoint_detection/keypoint_detection_18_0.png" /></p> <p>The plots show that we have images of non-uniform sizes, which is expected in most real-world scenarios. However, if we resize these images to have a uniform shape (for instance (224 x 224)) their ground-truth annotations will also be affected. The same applies if we apply any geometric transformation (horizontal flip, for e.g.) to an image. Fortunately, <code>imgaug</code> provides utilities that can handle this issue. In the next section, we will write a data generator inheriting the <a href="https://keras.io/api/utils/python_utils/#sequence-class"><code>keras.utils.Sequence</code></a> class that applies data augmentation on batches of data using <code>imgaug</code>.</p> <hr /> <h2 id="prepare-data-generator">Prepare data generator</h2> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">KeyPointsDataset</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">PyDataset</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">image_keys</span><span class="p">,</span> <span class="n">aug</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_keys</span> <span class="o">=</span> <span class="n">image_keys</span> <span class="bp">self</span><span class="o">.</span><span class="n">aug</span> <span class="o">=</span> <span class="n">aug</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="o">=</span> <span class="n">batch_size</span> <span class="bp">self</span><span class="o">.</span><span class="n">train</span> <span class="o">=</span> <span class="n">train</span> <span class="bp">self</span><span class="o">.</span><span class="n">on_epoch_end</span><span class="p">()</span> <span class="k">def</span> <span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="k">return</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">image_keys</span><span class="p">)</span> <span class="o">//</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="k">def</span> <span class="nf">on_epoch_end</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">indexes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">image_keys</span><span class="p">))</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">train</span><span class="p">:</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">indexes</span><span class="p">)</span> <span class="k">def</span> <span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">index</span><span class="p">):</span> <span class="n">indexes</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">indexes</span><span class="p">[</span><span class="n">index</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span> <span class="p">:</span> <span class="p">(</span><span class="n">index</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">]</span> <span class="n">image_keys_temp</span> <span class="o">=</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">image_keys</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">indexes</span><span class="p">]</span> <span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">keypoints</span><span class="p">)</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">__data_generation</span><span class="p">(</span><span class="n">image_keys_temp</span><span class="p">)</span> <span class="k">return</span> <span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">keypoints</span><span class="p">)</span> <span class="k">def</span> <span class="nf">__data_generation</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">image_keys_temp</span><span class="p">):</span> <span class="n">batch_images</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">IMG_SIZE</span><span class="p">,</span> <span class="n">IMG_SIZE</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int"</span><span class="p">)</span> <span class="n">batch_keypoints</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">NUM_KEYPOINTS</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span> <span class="p">)</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">key</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">image_keys_temp</span><span class="p">):</span> <span class="n">data</span> <span class="o">=</span> <span class="n">get_dog</span><span class="p">(</span><span class="n">key</span><span class="p">)</span> <span class="n">current_keypoint</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">data</span><span class="p">[</span><span class="s2">"joints"</span><span class="p">])[:,</span> <span class="p">:</span><span class="mi">2</span><span class="p">]</span> <span class="n">kps</span> <span class="o">=</span> <span class="p">[]</span> <span class="c1"># To apply our data augmentation pipeline, we first need to</span> <span class="c1"># form Keypoint objects with the original coordinates.</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">current_keypoint</span><span class="p">)):</span> <span class="n">kps</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">Keypoint</span><span class="p">(</span><span class="n">x</span><span class="o">=</span><span class="n">current_keypoint</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">0</span><span class="p">],</span> <span class="n">y</span><span class="o">=</span><span class="n">current_keypoint</span><span class="p">[</span><span class="n">j</span><span class="p">][</span><span class="mi">1</span><span class="p">]))</span> <span class="c1"># We then project the original image and its keypoint coordinates.</span> <span class="n">current_image</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="s2">"img_data"</span><span class="p">]</span> <span class="n">kps_obj</span> <span class="o">=</span> <span class="n">KeypointsOnImage</span><span class="p">(</span><span class="n">kps</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="n">current_image</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="c1"># Apply the augmentation pipeline.</span> <span class="p">(</span><span class="n">new_image</span><span class="p">,</span> <span class="n">new_kps_obj</span><span class="p">)</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">aug</span><span class="p">(</span><span class="n">image</span><span class="o">=</span><span class="n">current_image</span><span class="p">,</span> <span class="n">keypoints</span><span class="o">=</span><span class="n">kps_obj</span><span class="p">)</span> <span class="n">batch_images</span><span class="p">[</span><span class="n">i</span><span class="p">,]</span> <span class="o">=</span> <span class="n">new_image</span> <span class="c1"># Parse the coordinates from the new keypoint object.</span> <span class="n">kp_temp</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">keypoint</span> <span class="ow">in</span> <span class="n">new_kps_obj</span><span class="p">:</span> <span class="n">kp_temp</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">nan_to_num</span><span class="p">(</span><span class="n">keypoint</span><span class="o">.</span><span class="n">x</span><span class="p">))</span> <span class="n">kp_temp</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">nan_to_num</span><span class="p">(</span><span class="n">keypoint</span><span class="o">.</span><span class="n">y</span><span class="p">))</span> <span class="c1"># More on why this reshaping later.</span> <span class="n">batch_keypoints</span><span class="p">[</span><span class="n">i</span><span class="p">,]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">kp_temp</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">24</span> <span class="o">*</span> <span class="mi">2</span><span class="p">)</span> <span class="c1"># Scale the coordinates to [0, 1] range.</span> <span class="n">batch_keypoints</span> <span class="o">=</span> <span class="n">batch_keypoints</span> <span class="o">/</span> <span class="n">IMG_SIZE</span> <span class="k">return</span> <span class="p">(</span><span class="n">batch_images</span><span class="p">,</span> <span class="n">batch_keypoints</span><span class="p">)</span> </code></pre></div> <p>To know more about how to operate with keypoints in <code>imgaug</code> check out <a href="https://imgaug.readthedocs.io/en/latest/source/examples_keypoints.html">this document</a>.</p> <hr /> <h2 id="define-augmentation-transforms">Define augmentation transforms</h2> <div class="codehilite"><pre><span></span><code><span class="n">train_aug</span> <span class="o">=</span> <span class="n">iaa</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">iaa</span><span class="o">.</span><span class="n">Resize</span><span class="p">(</span><span class="n">IMG_SIZE</span><span class="p">,</span> <span class="n">interpolation</span><span class="o">=</span><span class="s2">"linear"</span><span class="p">),</span> <span class="n">iaa</span><span class="o">.</span><span class="n">Fliplr</span><span class="p">(</span><span class="mf">0.3</span><span class="p">),</span> <span class="c1"># `Sometimes()` applies a function randomly to the inputs with</span> <span class="c1"># a given probability (0.3, in this case).</span> <span class="n">iaa</span><span class="o">.</span><span class="n">Sometimes</span><span class="p">(</span><span class="mf">0.3</span><span class="p">,</span> <span class="n">iaa</span><span class="o">.</span><span class="n">Affine</span><span class="p">(</span><span class="n">rotate</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="p">(</span><span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">))),</span> <span class="p">]</span> <span class="p">)</span> <span class="n">test_aug</span> <span class="o">=</span> <span class="n">iaa</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span><span class="n">iaa</span><span class="o">.</span><span class="n">Resize</span><span class="p">(</span><span class="n">IMG_SIZE</span><span class="p">,</span> <span class="n">interpolation</span><span class="o">=</span><span class="s2">"linear"</span><span class="p">)])</span> </code></pre></div> <hr /> <h2 id="create-training-and-validation-splits">Create training and validation splits</h2> <div class="codehilite"><pre><span></span><code><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">samples</span><span class="p">)</span> <span class="n">train_keys</span><span class="p">,</span> <span class="n">validation_keys</span> <span class="o">=</span> <span class="p">(</span> <span class="n">samples</span><span class="p">[</span><span class="nb">int</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">samples</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.15</span><span class="p">)</span> <span class="p">:],</span> <span class="n">samples</span><span class="p">[:</span> <span class="nb">int</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">samples</span><span class="p">)</span> <span class="o">*</span> <span class="mf">0.15</span><span class="p">)],</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="data-generator-investigation">Data generator investigation</h2> <div class="codehilite"><pre><span></span><code><span class="n">train_dataset</span> <span class="o">=</span> <span class="n">KeyPointsDataset</span><span class="p">(</span> <span class="n">train_keys</span><span class="p">,</span> <span class="n">train_aug</span><span class="p">,</span> <span class="n">workers</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">use_multiprocessing</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="n">validation_dataset</span> <span class="o">=</span> <span class="n">KeyPointsDataset</span><span class="p">(</span> <span class="n">validation_keys</span><span class="p">,</span> <span class="n">test_aug</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">workers</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">use_multiprocessing</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Total batches in training set: </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Total batches in validation set: </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">validation_dataset</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="n">sample_images</span><span class="p">,</span> <span class="n">sample_keypoints</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">))</span> <span class="k">assert</span> <span class="n">sample_keypoints</span><span class="o">.</span><span class="n">max</span><span class="p">()</span> <span class="o">==</span> <span class="mf">1.0</span> <span class="k">assert</span> <span class="n">sample_keypoints</span><span class="o">.</span><span class="n">min</span><span class="p">()</span> <span class="o">==</span> <span class="mf">0.0</span> <span class="n">sample_keypoints</span> <span class="o">=</span> <span class="n">sample_keypoints</span><span class="p">[:</span><span class="mi">4</span><span class="p">]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">24</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="o">*</span> <span class="n">IMG_SIZE</span> <span class="n">visualize_keypoints</span><span class="p">(</span><span class="n">sample_images</span><span class="p">[:</span><span class="mi">4</span><span class="p">],</span> <span class="n">sample_keypoints</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Total batches in training set: 166 Total batches in validation set: 29 </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/keypoint_detection/keypoint_detection_28_1.png" /></p> <hr /> <h2 id="model-building">Model building</h2> <p>The <a href="http://vision.stanford.edu/aditya86/ImageNetDogs/">Stanford dogs dataset</a> (on which the StanfordExtra dataset is based) was built using the <a href="http://image-net.org/">ImageNet-1k dataset</a>. So, it is likely that the models pretrained on the ImageNet-1k dataset would be useful for this task. We will use a MobileNetV2 pre-trained on this dataset as a backbone to extract meaningful features from the images and then pass those to a custom regression head for predicting coordinates.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">get_model</span><span class="p">():</span> <span class="c1"># Load the pre-trained weights of MobileNetV2 and freeze the weights</span> <span class="n">backbone</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">applications</span><span class="o">.</span><span class="n">MobileNetV2</span><span class="p">(</span> <span class="n">weights</span><span class="o">=</span><span class="s2">"imagenet"</span><span class="p">,</span> <span class="n">include_top</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">input_shape</span><span class="o">=</span><span class="p">(</span><span class="n">IMG_SIZE</span><span class="p">,</span> <span class="n">IMG_SIZE</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="p">)</span> <span class="n">backbone</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">False</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">((</span><span class="n">IMG_SIZE</span><span class="p">,</span> <span class="n">IMG_SIZE</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">applications</span><span class="o">.</span><span class="n">mobilenet_v2</span><span class="o">.</span><span class="n">preprocess_input</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">backbone</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.3</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">SeparableConv2D</span><span class="p">(</span> <span class="n">NUM_KEYPOINTS</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span> <span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">SeparableConv2D</span><span class="p">(</span> <span class="n">NUM_KEYPOINTS</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"sigmoid"</span> <span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"keypoint_detector"</span><span class="p">)</span> </code></pre></div> <p>Our custom network is fully-convolutional which makes it more parameter-friendly than the same version of the network having fully-connected dense layers.</p> <div class="codehilite"><pre><span></span><code><span class="n">get_model</span><span class="p">()</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_224_no_top.h5 9406464/9406464 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step </code></pre></div> </div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "keypoint_detector"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ input_layer_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">InputLayer</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">224</span>, <span style="color: #00af00; text-decoration-color: #00af00">224</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ true_divide (<span style="color: #0087ff; text-decoration-color: #0087ff">TrueDivide</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">224</span>, <span style="color: #00af00; text-decoration-color: #00af00">224</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ subtract (<span style="color: #0087ff; text-decoration-color: #0087ff">Subtract</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">224</span>, <span style="color: #00af00; text-decoration-color: #00af00">224</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ mobilenetv2_1.00_224 │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">7</span>, <span style="color: #00af00; text-decoration-color: #00af00">7</span>, <span style="color: #00af00; text-decoration-color: #00af00">1280</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">2,257,984</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Functional</span>) │ │ │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dropout (<span style="color: #0087ff; text-decoration-color: #0087ff">Dropout</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">7</span>, <span style="color: #00af00; text-decoration-color: #00af00">7</span>, <span style="color: #00af00; text-decoration-color: #00af00">1280</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ separable_conv2d │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>, <span style="color: #00af00; text-decoration-color: #00af00">48</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">93,488</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">SeparableConv2D</span>) │ │ │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ separable_conv2d_1 │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>, <span style="color: #00af00; text-decoration-color: #00af00">48</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">2,784</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">SeparableConv2D</span>) │ │ │ └─────────────────────────────────┴───────────────────────────┴────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">2,354,256</span> (8.98 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">96,272</span> (376.06 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">2,257,984</span> (8.61 MB) </pre> <p>Notice the output shape of the network: <code>(None, 1, 1, 48)</code>. This is why we have reshaped the coordinates as: <code>batch_keypoints[i, :] = np.array(kp_temp).reshape(1, 1, 24 * 2)</code>.</p> <hr /> <h2 id="model-compilation-and-training">Model compilation and training</h2> <p>For this example, we will train the network only for five epochs.</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="n">get_model</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">loss</span><span class="o">=</span><span class="s2">"mse"</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="mf">1e-4</span><span class="p">))</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">validation_dataset</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">EPOCHS</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/5 166/166 ━━━━━━━━━━━━━━━━━━━━ 84s 415ms/step - loss: 0.1110 - val_loss: 0.0959 Epoch 2/5 166/166 ━━━━━━━━━━━━━━━━━━━━ 79s 472ms/step - loss: 0.0874 - val_loss: 0.0802 Epoch 3/5 166/166 ━━━━━━━━━━━━━━━━━━━━ 78s 463ms/step - loss: 0.0789 - val_loss: 0.0765 Epoch 4/5 166/166 ━━━━━━━━━━━━━━━━━━━━ 78s 467ms/step - loss: 0.0769 - val_loss: 0.0731 Epoch 5/5 166/166 ━━━━━━━━━━━━━━━━━━━━ 77s 464ms/step - loss: 0.0753 - val_loss: 0.0712 <keras.src.callbacks.history.History at 0x7fb5c4299ae0> </code></pre></div> </div> <hr /> <h2 id="make-predictions-and-visualize-them">Make predictions and visualize them</h2> <div class="codehilite"><pre><span></span><code><span class="n">sample_val_images</span><span class="p">,</span> <span class="n">sample_val_keypoints</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">validation_dataset</span><span class="p">))</span> <span class="n">sample_val_images</span> <span class="o">=</span> <span class="n">sample_val_images</span><span class="p">[:</span><span class="mi">4</span><span class="p">]</span> <span class="n">sample_val_keypoints</span> <span class="o">=</span> <span class="n">sample_val_keypoints</span><span class="p">[:</span><span class="mi">4</span><span class="p">]</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">24</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="o">*</span> <span class="n">IMG_SIZE</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">sample_val_images</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">24</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="o">*</span> <span class="n">IMG_SIZE</span> <span class="c1"># Ground-truth</span> <span class="n">visualize_keypoints</span><span class="p">(</span><span class="n">sample_val_images</span><span class="p">,</span> <span class="n">sample_val_keypoints</span><span class="p">)</span> <span class="c1"># Predictions</span> <span class="n">visualize_keypoints</span><span class="p">(</span><span class="n">sample_val_images</span><span class="p">,</span> <span class="n">predictions</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 1/1 ━━━━━━━━━━━━━━━━━━━━ 7s 7s/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/keypoint_detection/keypoint_detection_37_1.png" /></p> <p><img alt="png" src="/img/examples/vision/keypoint_detection/keypoint_detection_37_2.png" /></p> <p>Predictions will likely improve with more training.</p> <hr /> <h2 id="going-further">Going further</h2> <ul> <li>Try using other augmentation transforms from <code>imgaug</code> to investigate how that changes the results.</li> <li>Here, we transferred the features from the pre-trained network linearly that is we did not <a href="https://keras.io/guides/transfer_learning/">fine-tune</a> it. You are encouraged to fine-tune it on this task and see if that improves the performance. You can also try different architectures and see how they affect the final performance.</li> </ul> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#keypoint-detection-with-transfer-learning'>Keypoint Detection with Transfer Learning</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#data-collection'>Data collection</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#imports'>Imports</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#define-hyperparameters'>Define hyperparameters</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#load-data'>Load data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#visualize-data'>Visualize data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#prepare-data-generator'>Prepare data generator</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#define-augmentation-transforms'>Define augmentation transforms</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#create-training-and-validation-splits'>Create training and validation splits</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#data-generator-investigation'>Data generator investigation</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#model-building'>Model building</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#model-compilation-and-training'>Model compilation and training</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#make-predictions-and-visualize-them'>Make predictions and visualize them</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#going-further'>Going further</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>