CINXE.COM
Metric learning for image similarity search
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/metric_learning/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Metric learning for image similarity search"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Metric learning for image similarity search"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Metric learning for image similarity search</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a class="nav-sublink2" href="/examples/vision/mobilevit/">A mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using Composable Fully-Convolutional Networks</a> <a class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved Robustness</a> <a class="nav-sublink2" href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2 active" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/bit/">Image Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a class="nav-sublink2" href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/vision/'>Computer Vision</a> / Metric learning for image similarity search </div> <div class='k-content'> <h1 id="metric-learning-for-image-similarity-search">Metric learning for image similarity search</h1> <p><strong>Author:</strong> <a href="https://twitter.com/mat_kelcey">Mat Kelcey</a><br> <strong>Date created:</strong> 2020/06/05<br> <strong>Last modified:</strong> 2020/06/09<br> <strong>Description:</strong> Example of using similarity metric learning on CIFAR-10 images.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/metric_learning.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/metric_learning.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="overview">Overview</h2> <p>Metric learning aims to train models that can embed inputs into a high-dimensional space such that "similar" inputs, as defined by the training scheme, are located close to each other. These models once trained can produce embeddings for downstream systems where such similarity is useful; examples include as a ranking signal for search or as a form of pretrained embedding model for another supervised problem.</p> <p>For a more detailed overview of metric learning see:</p> <ul> <li><a href="http://contrib.scikit-learn.org/metric-learn/introduction.html">What is metric learning?</a></li> <li><a href="https://www.youtube.com/watch?v=Jb4Ewl5RzkI">"Using crossentropy for metric learning" tutorial</a></li> </ul> <hr /> <h2 id="setup">Setup</h2> <p>Set Keras backend to tensorflow.</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span><span class="w"> </span><span class="nn">os</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"tensorflow"</span> <span class="kn">import</span><span class="w"> </span><span class="nn">random</span> <span class="kn">import</span><span class="w"> </span><span class="nn">matplotlib.pyplot</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">plt</span> <span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span> <span class="kn">import</span><span class="w"> </span><span class="nn">tensorflow</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">tf</span> <span class="kn">from</span><span class="w"> </span><span class="nn">collections</span><span class="w"> </span><span class="kn">import</span> <span class="n">defaultdict</span> <span class="kn">from</span><span class="w"> </span><span class="nn">PIL</span><span class="w"> </span><span class="kn">import</span> <span class="n">Image</span> <span class="kn">from</span><span class="w"> </span><span class="nn">sklearn.metrics</span><span class="w"> </span><span class="kn">import</span> <span class="n">ConfusionMatrixDisplay</span> <span class="kn">import</span><span class="w"> </span><span class="nn">keras</span> <span class="kn">from</span><span class="w"> </span><span class="nn">keras</span><span class="w"> </span><span class="kn">import</span> <span class="n">layers</span> </code></pre></div> <hr /> <h2 id="dataset">Dataset</h2> <p>For this example we will be using the <a href="https://www.cs.toronto.edu/~kriz/cifar.html">CIFAR-10</a> dataset.</p> <div class="codehilite"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">keras.datasets</span><span class="w"> </span><span class="kn">import</span> <span class="n">cifar10</span> <span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">cifar10</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span> <span class="n">x_train</span> <span class="o">=</span> <span class="n">x_train</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">)</span> <span class="o">/</span> <span class="mf">255.0</span> <span class="n">y_train</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">y_train</span><span class="p">)</span> <span class="n">x_test</span> <span class="o">=</span> <span class="n">x_test</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">)</span> <span class="o">/</span> <span class="mf">255.0</span> <span class="n">y_test</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">y_test</span><span class="p">)</span> </code></pre></div> <p>To get a sense of the dataset we can visualise a grid of 25 random examples.</p> <div class="codehilite"><pre><span></span><code><span class="n">height_width</span> <span class="o">=</span> <span class="mi">32</span> <span class="k">def</span><span class="w"> </span><span class="nf">show_collage</span><span class="p">(</span><span class="n">examples</span><span class="p">):</span> <span class="n">box_size</span> <span class="o">=</span> <span class="n">height_width</span> <span class="o">+</span> <span class="mi">2</span> <span class="n">num_rows</span><span class="p">,</span> <span class="n">num_cols</span> <span class="o">=</span> <span class="n">examples</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="mi">2</span><span class="p">]</span> <span class="n">collage</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">new</span><span class="p">(</span> <span class="n">mode</span><span class="o">=</span><span class="s2">"RGB"</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="n">num_cols</span> <span class="o">*</span> <span class="n">box_size</span><span class="p">,</span> <span class="n">num_rows</span> <span class="o">*</span> <span class="n">box_size</span><span class="p">),</span> <span class="n">color</span><span class="o">=</span><span class="p">(</span><span class="mi">250</span><span class="p">,</span> <span class="mi">250</span><span class="p">,</span> <span class="mi">250</span><span class="p">),</span> <span class="p">)</span> <span class="k">for</span> <span class="n">row_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_rows</span><span class="p">):</span> <span class="k">for</span> <span class="n">col_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_cols</span><span class="p">):</span> <span class="n">array</span> <span class="o">=</span> <span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">examples</span><span class="p">[</span><span class="n">row_idx</span><span class="p">,</span> <span class="n">col_idx</span><span class="p">])</span> <span class="o">*</span> <span class="mi">255</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span> <span class="n">collage</span><span class="o">.</span><span class="n">paste</span><span class="p">(</span> <span class="n">Image</span><span class="o">.</span><span class="n">fromarray</span><span class="p">(</span><span class="n">array</span><span class="p">),</span> <span class="p">(</span><span class="n">col_idx</span> <span class="o">*</span> <span class="n">box_size</span><span class="p">,</span> <span class="n">row_idx</span> <span class="o">*</span> <span class="n">box_size</span><span class="p">)</span> <span class="p">)</span> <span class="c1"># Double size for visualisation.</span> <span class="n">collage</span> <span class="o">=</span> <span class="n">collage</span><span class="o">.</span><span class="n">resize</span><span class="p">((</span><span class="mi">2</span> <span class="o">*</span> <span class="n">num_cols</span> <span class="o">*</span> <span class="n">box_size</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">num_rows</span> <span class="o">*</span> <span class="n">box_size</span><span class="p">))</span> <span class="k">return</span> <span class="n">collage</span> <span class="c1"># Show a collage of 5x5 random images.</span> <span class="n">sample_idxs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">50000</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span> <span class="n">examples</span> <span class="o">=</span> <span class="n">x_train</span><span class="p">[</span><span class="n">sample_idxs</span><span class="p">]</span> <span class="n">show_collage</span><span class="p">(</span><span class="n">examples</span><span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/metric_learning/metric_learning_7_0.png" /></p> <p>Metric learning provides training data not as explicit <code>(X, y)</code> pairs but instead uses multiple instances that are related in the way we want to express similarity. In our example we will use instances of the same class to represent similarity; a single training instance will not be one image, but a pair of images of the same class. When referring to the images in this pair we'll use the common metric learning names of the <code>anchor</code> (a randomly chosen image) and the <code>positive</code> (another randomly chosen image of the same class).</p> <p>To facilitate this we need to build a form of lookup that maps from classes to the instances of that class. When generating data for training we will sample from this lookup.</p> <div class="codehilite"><pre><span></span><code><span class="n">class_idx_to_train_idxs</span> <span class="o">=</span> <span class="n">defaultdict</span><span class="p">(</span><span class="nb">list</span><span class="p">)</span> <span class="k">for</span> <span class="n">y_train_idx</span><span class="p">,</span> <span class="n">y</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">y_train</span><span class="p">):</span> <span class="n">class_idx_to_train_idxs</span><span class="p">[</span><span class="n">y</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">y_train_idx</span><span class="p">)</span> <span class="n">class_idx_to_test_idxs</span> <span class="o">=</span> <span class="n">defaultdict</span><span class="p">(</span><span class="nb">list</span><span class="p">)</span> <span class="k">for</span> <span class="n">y_test_idx</span><span class="p">,</span> <span class="n">y</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">y_test</span><span class="p">):</span> <span class="n">class_idx_to_test_idxs</span><span class="p">[</span><span class="n">y</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">y_test_idx</span><span class="p">)</span> </code></pre></div> <p>For this example we are using the simplest approach to training; a batch will consist of <code>(anchor, positive)</code> pairs spread across the classes. The goal of learning will be to move the anchor and positive pairs closer together and further away from other instances in the batch. In this case the batch size will be dictated by the number of classes; for CIFAR-10 this is 10.</p> <div class="codehilite"><pre><span></span><code><span class="n">num_classes</span> <span class="o">=</span> <span class="mi">10</span> <span class="k">class</span><span class="w"> </span><span class="nc">AnchorPositivePairs</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">Sequence</span><span class="p">):</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_batches</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_batches</span> <span class="o">=</span> <span class="n">num_batches</span> <span class="k">def</span><span class="w"> </span><span class="fm">__len__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_batches</span> <span class="k">def</span><span class="w"> </span><span class="fm">__getitem__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">_idx</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">empty</span><span class="p">((</span><span class="mi">2</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">,</span> <span class="n">height_width</span><span class="p">,</span> <span class="n">height_width</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="k">for</span> <span class="n">class_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_classes</span><span class="p">):</span> <span class="n">examples_for_class</span> <span class="o">=</span> <span class="n">class_idx_to_train_idxs</span><span class="p">[</span><span class="n">class_idx</span><span class="p">]</span> <span class="n">anchor_idx</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">examples_for_class</span><span class="p">)</span> <span class="n">positive_idx</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">examples_for_class</span><span class="p">)</span> <span class="k">while</span> <span class="n">positive_idx</span> <span class="o">==</span> <span class="n">anchor_idx</span><span class="p">:</span> <span class="n">positive_idx</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">examples_for_class</span><span class="p">)</span> <span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">class_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">x_train</span><span class="p">[</span><span class="n">anchor_idx</span><span class="p">]</span> <span class="n">x</span><span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="n">class_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">x_train</span><span class="p">[</span><span class="n">positive_idx</span><span class="p">]</span> <span class="k">return</span> <span class="n">x</span> </code></pre></div> <p>We can visualise a batch in another collage. The top row shows randomly chosen anchors from the 10 classes, the bottom row shows the corresponding 10 positives.</p> <div class="codehilite"><pre><span></span><code><span class="n">examples</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">AnchorPositivePairs</span><span class="p">(</span><span class="n">num_batches</span><span class="o">=</span><span class="mi">1</span><span class="p">)))</span> <span class="n">show_collage</span><span class="p">(</span><span class="n">examples</span><span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/metric_learning/metric_learning_13_0.png" /></p> <hr /> <h2 id="embedding-model">Embedding model</h2> <p>We define a custom model with a <code>train_step</code> that first embeds both anchors and positives and then uses their pairwise dot products as logits for a softmax.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">EmbeddingModel</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span><span class="w"> </span><span class="nf">train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span> <span class="c1"># Note: Workaround for open issue, to be removed.</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="nb">tuple</span><span class="p">):</span> <span class="n">data</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="n">anchors</span><span class="p">,</span> <span class="n">positives</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">data</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="c1"># Run both anchors and positives through model.</span> <span class="n">anchor_embeddings</span> <span class="o">=</span> <span class="bp">self</span><span class="p">(</span><span class="n">anchors</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">positive_embeddings</span> <span class="o">=</span> <span class="bp">self</span><span class="p">(</span><span class="n">positives</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="c1"># Calculate cosine similarity between anchors and positives. As they have</span> <span class="c1"># been normalised this is just the pair wise dot products.</span> <span class="n">similarities</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span> <span class="s2">"ae,pe->ap"</span><span class="p">,</span> <span class="n">anchor_embeddings</span><span class="p">,</span> <span class="n">positive_embeddings</span> <span class="p">)</span> <span class="c1"># Since we intend to use these as logits we scale them by a temperature.</span> <span class="c1"># This value would normally be chosen as a hyper parameter.</span> <span class="n">temperature</span> <span class="o">=</span> <span class="mf">0.2</span> <span class="n">similarities</span> <span class="o">/=</span> <span class="n">temperature</span> <span class="c1"># We use these similarities as logits for a softmax. The labels for</span> <span class="c1"># this call are just the sequence [0, 1, 2, ..., num_classes] since we</span> <span class="c1"># want the main diagonal values, which correspond to the anchor/positive</span> <span class="c1"># pairs, to be high. This loss will move embeddings for the</span> <span class="c1"># anchor/positive pairs together and move all other pairs apart.</span> <span class="n">sparse_labels</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_classes</span><span class="p">)</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">compute_loss</span><span class="p">(</span><span class="n">y</span><span class="o">=</span><span class="n">sparse_labels</span><span class="p">,</span> <span class="n">y_pred</span><span class="o">=</span><span class="n">similarities</span><span class="p">)</span> <span class="c1"># Calculate gradients and apply via optimizer.</span> <span class="n">gradients</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainable_variables</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">trainable_variables</span><span class="p">))</span> <span class="c1"># Update and return metrics (specifically the one for the loss value).</span> <span class="k">for</span> <span class="n">metric</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics</span><span class="p">:</span> <span class="c1"># Calling `self.compile` will by default add a [`keras.metrics.Mean`](/api/metrics/metrics_wrappers#mean-class) loss</span> <span class="k">if</span> <span class="n">metric</span><span class="o">.</span><span class="n">name</span> <span class="o">==</span> <span class="s2">"loss"</span><span class="p">:</span> <span class="n">metric</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">metric</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">sparse_labels</span><span class="p">,</span> <span class="n">similarities</span><span class="p">)</span> <span class="k">return</span> <span class="p">{</span><span class="n">m</span><span class="o">.</span><span class="n">name</span><span class="p">:</span> <span class="n">m</span><span class="o">.</span><span class="n">result</span><span class="p">()</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics</span><span class="p">}</span> </code></pre></div> <p>Next we describe the architecture that maps from an image to an embedding. This model simply consists of a sequence of 2d convolutions followed by global pooling with a final linear projection to an embedding space. As is common in metric learning we normalise the embeddings so that we can use simple dot products to measure similarity. For simplicity this model is intentionally small.</p> <div class="codehilite"><pre><span></span><code><span class="n">inputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">height_width</span><span class="p">,</span> <span class="n">height_width</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">GlobalAveragePooling2D</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">embeddings</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="kc">None</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">embeddings</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">UnitNormalization</span><span class="p">()(</span><span class="n">embeddings</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">EmbeddingModel</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">embeddings</span><span class="p">)</span> </code></pre></div> <p>Finally we run the training. On a Google Colab GPU instance this takes about a minute.</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">1e-3</span><span class="p">),</span> <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span> <span class="p">)</span> <span class="n">history</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">AnchorPositivePairs</span><span class="p">(</span><span class="n">num_batches</span><span class="o">=</span><span class="mi">1000</span><span class="p">),</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">20</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">"loss"</span><span class="p">])</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/20 77/1000 ━[37m━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 2.2962 WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1700589927.295343 3724442 device_compiler.h:187] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 6s 2ms/step - loss: 2.2504 Epoch 2/20 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 2.1068 Epoch 3/20 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 2.0646 Epoch 4/20 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 2.0210 Epoch 5/20 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.9857 Epoch 6/20 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.9543 Epoch 7/20 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.9175 Epoch 8/20 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.8740 Epoch 9/20 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.8474 Epoch 10/20 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.8380 Epoch 11/20 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.8146 Epoch 12/20 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7658 Epoch 13/20 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7512 Epoch 14/20 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7671 Epoch 15/20 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7245 Epoch 16/20 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7001 Epoch 17/20 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.7099 Epoch 18/20 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.6775 Epoch 19/20 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.6547 Epoch 20/20 1000/1000 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - loss: 1.6356 </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/metric_learning/metric_learning_19_3.png" /></p> <hr /> <h2 id="testing">Testing</h2> <p>We can review the quality of this model by applying it to the test set and considering near neighbours in the embedding space.</p> <p>First we embed the test set and calculate all near neighbours. Recall that since the embeddings are unit length we can calculate cosine similarity via dot products.</p> <div class="codehilite"><pre><span></span><code><span class="n">near_neighbours_per_example</span> <span class="o">=</span> <span class="mi">10</span> <span class="n">embeddings</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">x_test</span><span class="p">)</span> <span class="n">gram_matrix</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">einsum</span><span class="p">(</span><span class="s2">"ae,be->ab"</span><span class="p">,</span> <span class="n">embeddings</span><span class="p">,</span> <span class="n">embeddings</span><span class="p">)</span> <span class="n">near_neighbours</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argsort</span><span class="p">(</span><span class="n">gram_matrix</span><span class="o">.</span><span class="n">T</span><span class="p">)[:,</span> <span class="o">-</span><span class="p">(</span><span class="n">near_neighbours_per_example</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="p">:]</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step </code></pre></div> </div> <p>As a visual check of these embeddings we can build a collage of the near neighbours for 5 random examples. The first column of the image below is a randomly selected image, the following 10 columns show the nearest neighbours in order of similarity.</p> <div class="codehilite"><pre><span></span><code><span class="n">num_collage_examples</span> <span class="o">=</span> <span class="mi">5</span> <span class="n">examples</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">empty</span><span class="p">(</span> <span class="p">(</span> <span class="n">num_collage_examples</span><span class="p">,</span> <span class="n">near_neighbours_per_example</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">height_width</span><span class="p">,</span> <span class="n">height_width</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">,</span> <span class="p">)</span> <span class="k">for</span> <span class="n">row_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_collage_examples</span><span class="p">):</span> <span class="n">examples</span><span class="p">[</span><span class="n">row_idx</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="n">x_test</span><span class="p">[</span><span class="n">row_idx</span><span class="p">]</span> <span class="n">anchor_near_neighbours</span> <span class="o">=</span> <span class="nb">reversed</span><span class="p">(</span><span class="n">near_neighbours</span><span class="p">[</span><span class="n">row_idx</span><span class="p">][:</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span> <span class="k">for</span> <span class="n">col_idx</span><span class="p">,</span> <span class="n">nn_idx</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">anchor_near_neighbours</span><span class="p">):</span> <span class="n">examples</span><span class="p">[</span><span class="n">row_idx</span><span class="p">,</span> <span class="n">col_idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="n">x_test</span><span class="p">[</span><span class="n">nn_idx</span><span class="p">]</span> <span class="n">show_collage</span><span class="p">(</span><span class="n">examples</span><span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/metric_learning/metric_learning_23_0.png" /></p> <p>We can also get a quantified view of the performance by considering the correctness of near neighbours in terms of a confusion matrix.</p> <p>Let us sample 10 examples from each of the 10 classes and consider their near neighbours as a form of prediction; that is, does the example and its near neighbours share the same class?</p> <p>We observe that each animal class does generally well, and is confused the most with the other animal classes. The vehicle classes follow the same pattern.</p> <div class="codehilite"><pre><span></span><code><span class="n">confusion_matrix</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">))</span> <span class="c1"># For each class.</span> <span class="k">for</span> <span class="n">class_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_classes</span><span class="p">):</span> <span class="c1"># Consider 10 examples.</span> <span class="n">example_idxs</span> <span class="o">=</span> <span class="n">class_idx_to_test_idxs</span><span class="p">[</span><span class="n">class_idx</span><span class="p">][:</span><span class="mi">10</span><span class="p">]</span> <span class="k">for</span> <span class="n">y_test_idx</span> <span class="ow">in</span> <span class="n">example_idxs</span><span class="p">:</span> <span class="c1"># And count the classes of its near neighbours.</span> <span class="k">for</span> <span class="n">nn_idx</span> <span class="ow">in</span> <span class="n">near_neighbours</span><span class="p">[</span><span class="n">y_test_idx</span><span class="p">][:</span><span class="o">-</span><span class="mi">1</span><span class="p">]:</span> <span class="n">nn_class_idx</span> <span class="o">=</span> <span class="n">y_test</span><span class="p">[</span><span class="n">nn_idx</span><span class="p">]</span> <span class="n">confusion_matrix</span><span class="p">[</span><span class="n">class_idx</span><span class="p">,</span> <span class="n">nn_class_idx</span><span class="p">]</span> <span class="o">+=</span> <span class="mi">1</span> <span class="c1"># Display a confusion matrix.</span> <span class="n">labels</span> <span class="o">=</span> <span class="p">[</span> <span class="s2">"Airplane"</span><span class="p">,</span> <span class="s2">"Automobile"</span><span class="p">,</span> <span class="s2">"Bird"</span><span class="p">,</span> <span class="s2">"Cat"</span><span class="p">,</span> <span class="s2">"Deer"</span><span class="p">,</span> <span class="s2">"Dog"</span><span class="p">,</span> <span class="s2">"Frog"</span><span class="p">,</span> <span class="s2">"Horse"</span><span class="p">,</span> <span class="s2">"Ship"</span><span class="p">,</span> <span class="s2">"Truck"</span><span class="p">,</span> <span class="p">]</span> <span class="n">disp</span> <span class="o">=</span> <span class="n">ConfusionMatrixDisplay</span><span class="p">(</span><span class="n">confusion_matrix</span><span class="o">=</span><span class="n">confusion_matrix</span><span class="p">,</span> <span class="n">display_labels</span><span class="o">=</span><span class="n">labels</span><span class="p">)</span> <span class="n">disp</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">include_values</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s2">"viridis"</span><span class="p">,</span> <span class="n">ax</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">xticks_rotation</span><span class="o">=</span><span class="s2">"vertical"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/metric_learning/metric_learning_25_0.png" /></p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#metric-learning-for-image-similarity-search'>Metric learning for image similarity search</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#overview'>Overview</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#dataset'>Dataset</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#embedding-model'>Embedding model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#testing'>Testing</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>