<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/near_dup_search/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Near-duplicate image search"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Near-duplicate image search"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Near-duplicate image search</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag 
Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a class="nav-sublink2" href="/examples/vision/mobilevit/">A mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a 
class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using Composable Fully-Convolutional Networks</a> <a class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer 
Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved Robustness</a> <a class="nav-sublink2" href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" 
href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2 active" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/bit/">Image Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" 
href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a class="nav-sublink2" href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" 
href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div 
id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/vision/'>Computer Vision</a> / Near-duplicate image search </div> <div class='k-content'> <h1 
id="nearduplicate-image-search">Near-duplicate image search</h1> <p><strong>Author:</strong> <a href="https://twitter.com/RisingSayak">Sayak Paul</a><br> <strong>Date created:</strong> 2021/09/10<br> <strong>Last modified:</strong> 2023/08/30<br></p> <div class='example_version_banner keras_2'>ⓘ This example uses Keras 2</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/near_dup_search.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/near_dup_search.py"><strong>GitHub source</strong></a></p> <p><strong>Description:</strong> Building a near-duplicate image search utility using deep learning and locality-sensitive hashing.</p> <hr /> <h2 id="introduction">Introduction</h2> <p>Fetching similar images in (near) real time is an important use case of information retrieval systems. Some popular products utilizing it include Pinterest, Google Image Search, etc. In this example, we will build a similar image search utility using <a href="https://towardsdatascience.com/understanding-locality-sensitive-hashing-49f6d1f6134">Locality Sensitive Hashing</a> (LSH) and <a href="https://en.wikipedia.org/wiki/Random_projection">random projection</a> on top of the image representations computed by a pretrained image classifier. This kind of search engine is also known as a <em>near-duplicate (or near-dup) image detector</em>. 
We will also look into optimizing the inference performance of our search utility on GPU using <a href="https://developer.nvidia.com/tensorrt">TensorRT</a>.</p> <p>There are other examples under <a href="https://keras.io/examples/vision">keras.io/examples/vision</a> that are worth checking out in this regard:</p> <ul> <li><a href="https://keras.io/examples/vision/metric_learning">Metric learning for image similarity search</a></li> <li><a href="https://keras.io/examples/vision/siamese_network">Image similarity estimation using a Siamese Network with a triplet loss</a></li> </ul> <p>Finally, this example uses the following resource as a reference and as such reuses some of its code: <a href="https://towardsdatascience.com/locality-sensitive-hashing-for-music-search-f2f1940ace23">Locality Sensitive Hashing for Similar Item Search</a>.</p> <p><em>Note that you should have a GPU runtime available in order to optimize the performance of our search utility.</em></p> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="err">!</span><span class="n">pip</span> <span class="n">install</span> <span class="n">tensorrt</span>
</code></pre></div> <hr /> <h2 id="imports">Imports</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span>
<span class="kn">import</span> <span class="nn">tensorrt</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">time</span>
<span class="kn">import</span> <span class="nn">tensorflow_datasets</span> <span class="k">as</span> <span class="nn">tfds</span>

<span class="n">tfds</span><span class="o">.</span><span class="n">disable_progress_bar</span><span
class="p">()</span>
</code></pre></div> <hr /> <h2 id="load-the-dataset-and-create-a-training-set-of-1000-images">Load the dataset and create a training set of 1,000 images</h2> <p>To keep the run time of the example short, we will be using a subset of 1,000 images from the <code>tf_flowers</code> dataset (available through <a href="https://www.tensorflow.org/datasets/catalog/tf_flowers">TensorFlow Datasets</a>) to build our search index.</p> <div class="codehilite"><pre><span></span><code><span class="n">train_ds</span><span class="p">,</span> <span class="n">validation_ds</span> <span class="o">=</span> <span class="n">tfds</span><span class="o">.</span><span class="n">load</span><span class="p">(</span>
    <span class="s2">"tf_flowers"</span><span class="p">,</span> <span class="n">split</span><span class="o">=</span><span class="p">[</span><span class="s2">"train[:85%]"</span><span class="p">,</span> <span class="s2">"train[85%:]"</span><span class="p">],</span> <span class="n">as_supervised</span><span class="o">=</span><span class="kc">True</span>
<span class="p">)</span>

<span class="n">IMAGE_SIZE</span> <span class="o">=</span> <span class="mi">224</span>
<span class="n">NUM_IMAGES</span> <span class="o">=</span> <span class="mi">1000</span>

<span class="n">images</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">labels</span> <span class="o">=</span> <span class="p">[]</span>

<span class="k">for</span> <span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">label</span><span class="p">)</span> <span class="ow">in</span> <span class="n">train_ds</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="n">NUM_IMAGES</span><span class="p">):</span>
    <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span
class="n">image</span><span class="p">,</span> <span class="p">(</span><span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="n">IMAGE_SIZE</span><span class="p">))</span>
    <span class="n">images</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">image</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
    <span class="n">labels</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">label</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>

<span class="n">images</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">images</span><span class="p">)</span>
<span class="n">labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span>
</code></pre></div> <hr /> <h2 id="load-a-pretrained-model">Load a pre-trained model</h2> <p>In this section, we load an image classification model that was trained on the <code>tf_flowers</code> dataset. 85% of the total images were used to build the training set. For more details on the training, refer to <a href="https://github.com/sayakpaul/near-dup-parser/blob/main/bit-supervised-training.ipynb">this notebook</a>.</p> <p>The underlying model is a BiT-ResNet (proposed in <a href="https://arxiv.org/abs/1912.11370">Big Transfer (BiT): General Visual Representation Learning</a>). 
The BiT-ResNet family of models is known to provide excellent transfer performance across a wide variety of downstream tasks.</p> <div class="codehilite"><pre><span></span><code><span class="err">!</span><span class="n">wget</span> <span class="o">-</span><span class="n">q</span> <span class="n">https</span><span class="p">:</span><span class="o">//</span><span class="n">github</span><span class="o">.</span><span class="n">com</span><span class="o">/</span><span class="n">sayakpaul</span><span class="o">/</span><span class="n">near</span><span class="o">-</span><span class="n">dup</span><span class="o">-</span><span class="n">parser</span><span class="o">/</span><span class="n">releases</span><span class="o">/</span><span class="n">download</span><span class="o">/</span><span class="n">v0</span><span class="mf">.1.0</span><span class="o">/</span><span class="n">flower_model_bit_0</span><span class="mf">.96875</span><span class="o">.</span><span class="n">zip</span>
<span class="err">!</span><span class="n">unzip</span> <span class="o">-</span><span class="n">qq</span> <span class="n">flower_model_bit_0</span><span class="mf">.96875</span><span class="o">.</span><span class="n">zip</span>
</code></pre></div> <div class="codehilite"><pre><span></span><code><span class="n">bit_model</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">load_model</span><span class="p">(</span><span class="s2">"flower_model_bit_0.96875"</span><span class="p">)</span>
<span class="n">bit_model</span><span class="o">.</span><span class="n">count_params</span><span class="p">()</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>23510597
</code></pre></div> </div> <hr /> <h2 id="create-an-embedding-model">Create an embedding model</h2> <p>To retrieve similar images given a
query image, we first need to generate vector representations of all the images involved. We do this via an embedding model that extracts output features from our pretrained classifier and passes them through a normalization layer.</p> <div class="codehilite"><pre><span></span><code><span class="n">embedding_model</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
    <span class="p">[</span>
        <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">((</span><span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="mi">3</span><span class="p">)),</span>
        <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Rescaling</span><span class="p">(</span><span class="n">scale</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">/</span> <span class="mi">255</span><span class="p">),</span>
        <span class="n">bit_model</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span>
        <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Normalization</span><span class="p">(</span><span class="n">mean</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">variance</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span>
    <span class="p">],</span>
    <span class="n">name</span><span class="o">=</span><span class="s2">"embedding_model"</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">embedding_model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Model: "embedding_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
rescaling (Rescaling)        (None, 224, 224, 3)       0
_________________________________________________________________
keras_layer (KerasLayer)     (None, 2048)              23500352
_________________________________________________________________
normalization (Normalization (None, 2048)              0
=================================================================
Total params: 23,500,352
Trainable params: 23,500,352
Non-trainable params: 0
_________________________________________________________________
</code></pre></div> </div> <p>Take note of the normalization layer at the end of the model. Configured with <code>mean=0</code> and <code>variance=1</code>, a Keras <code>Normalization</code> layer computes <code>(x - 0) / sqrt(1)</code>, so it passes the features through unchanged rather than projecting them onto the unit sphere. This does not affect the sign-based hashing below, because the sign of a dot product does not depend on the length of the embedding vector. If unit-length embeddings are needed, <code>tf.keras.layers.UnitNormalization</code> performs that projection.</p> <hr /> <h2 id="hashing-utilities">Hashing utilities</h2> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">hash_func</span><span class="p">(</span><span class="n">embedding</span><span class="p">,</span> <span class="n">random_vectors</span><span class="p">):</span>
    <span class="n">embedding</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">embedding</span><span class="p">)</span>

    <span class="c1"># Random projection.</span>
    <span class="n">bools</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">dot</span><span class="p">(</span><span class="n">embedding</span><span class="p">,</span> <span class="n">random_vectors</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span>
    <span class="c1"># Pack each row of sign bits into an integer bucket id.</span>
    <span
class="k">return</span> <span class="p">[</span><span class="n">bool2int</span><span class="p">(</span><span class="n">bool_vec</span><span class="p">)</span> <span class="k">for</span> <span class="n">bool_vec</span> <span class="ow">in</span> <span class="n">bools</span><span class="p">]</span>


<span class="k">def</span> <span class="nf">bool2int</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="c1"># Bit i of the result is set when x[i] is True.</span>
    <span class="n">y</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
        <span class="k">if</span> <span class="n">j</span><span class="p">:</span>
            <span class="n">y</span> <span class="o">+=</span> <span class="mi">1</span> <span class="o">&lt;&lt;</span> <span class="n">i</span>
    <span class="k">return</span> <span class="n">y</span>
</code></pre></div> <p>The vectors coming out of <code>embedding_model</code> have shape <code>(2048,)</code>, which is quite large once practical aspects such as storage and retrieval speed are taken into account. So we need to reduce the dimensionality of the embedding vectors while preserving as much of their information as possible. This is where <em>random projection</em> comes into the picture. It is based on the principle that projecting a set of high-dimensional points onto a lower-dimensional space spanned by random vectors <em>approximately</em> preserves the distances between them.</p> <p>Inside <code>hash_func()</code>, we first reduce the dimensionality of the embedding vectors by projecting them onto <code>random_vectors</code>, keeping only the sign of each projection. <code>bool2int()</code> then packs these sign bits into an integer, the bitwise hash value that determines the image's hash bucket. Images with the same hash value go into the same bucket, and similar images are likely to end up with the same hash value. 
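</p> <p>To make the random-projection argument concrete, the following NumPy sketch (not part of the original example) re-creates the sign-based hashing of <code>hash_func()</code> and <code>bool2int()</code> and checks a known property of it: for a random Gaussian direction, two vectors separated by an angle θ fall on the same side with probability 1 - θ/π. A small angle therefore means agreement on most hash bits, which is why near-duplicates tend to share a bucket. The helper names (<code>angle</code>, <code>bit_agreement</code>), the seeds, the noise level, and the use of 128 dimensions instead of 2,048 (only to keep it fast) are illustrative assumptions.</p>

```python
import numpy as np


def bool2int(bits):
    # Pack a boolean vector into an integer: bit i is set when bits[i] is True.
    value = 0
    for i, bit in enumerate(bits):
        if bit:
            value += 1 << i
    return value


def hash_func(embedding, random_vectors):
    # One hash bit per random direction: True where the projection is positive.
    bools = np.dot(np.asarray(embedding), random_vectors) > 0
    return [bool2int(row) for row in bools]


def angle(u, v):
    # Angle between two vectors, in radians.
    cos = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
    return float(np.arccos(np.clip(cos, -1.0, 1.0)))


def bit_agreement(u, v, num_planes, dim, seed=0):
    # Fraction of random Gaussian directions for which u and v get the same bit.
    rng = np.random.default_rng(seed)
    planes = rng.standard_normal((dim, num_planes))
    return float(np.mean((u @ planes > 0) == (v @ planes > 0)))


rng = np.random.default_rng(42)
dim = 128
x = rng.standard_normal(dim)
near = x + 0.3 * rng.standard_normal(dim)  # a perturbed copy ("near-duplicate")
far = rng.standard_normal(dim)  # an unrelated vector

for name, y in [("near", near), ("far", far)]:
    expected = 1 - angle(x, y) / np.pi
    observed = bit_agreement(x, y, num_planes=20_000, dim=dim)
    print(f"{name}: expected {expected:.3f}, observed {observed:.3f}")

# 8 random directions, laid out like Table.random_vectors: shape (dim, hash_size).
planes = rng.standard_normal((8, dim)).T
print(hash_func([x], planes) == hash_func([x / np.linalg.norm(x)], planes))
```

<p>Because each bit depends only on the sign of a dot product, rescaling a vector by a positive factor, such as dividing it by its L2 norm, leaves its hash unchanged; the last line of the sketch checks this. Only the direction of an embedding matters to this kind of hash.</p> <p>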
From a deployment perspective, bitwise hash values are cheaper to store and operate on.</p> <hr /> <h2 id="query-utilities">Query utilities</h2> <p>The <code>Table</code> class is responsible for building a single hash table. Each entry in the hash table maps a hash value, computed from the reduced embedding of an image in our dataset, to the unique identifiers of the images that fall into that bucket. Because our dimensionality reduction technique involves randomness, similar images are not guaranteed to be mapped to the same hash bucket every time the process runs. To reduce this effect, we will combine the results from multiple tables. The number of tables and the reduction dimensionality (the hash size) are the key hyperparameters here.</p> <p>Note that in real-world applications, you would typically not implement locality-sensitive hashing yourself. Instead, you would likely use one of the following popular approximate nearest-neighbor search libraries:</p> <ul> <li><a href="https://github.com/google-research/google-research/tree/master/scann">ScaNN</a></li> <li><a href="https://github.com/spotify/annoy">Annoy</a></li> <li><a href="https://github.com/vdaas/vald">Vald</a></li> </ul> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">Table</span><span class="p">:</span>
    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hash_size</span><span class="p">,</span> <span class="n">dim</span><span class="p">):</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">table</span> <span class="o">=</span> <span class="p">{}</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">hash_size</span> <span class="o">=</span> <span class="n">hash_size</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">random_vectors</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span
class="n">randn</span><span class="p">(</span><span class="n">hash_size</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span><span class="o">.</span><span class="n">T</span> <span class="k">def</span> <span class="nf">add</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">id</span><span class="p">,</span> <span class="n">vectors</span><span class="p">,</span> <span class="n">label</span><span class="p">):</span> <span class="c1"># Create a unique identifier.</span> <span class="n">entry</span> <span class="o">=</span> <span class="p">{</span><span class="s2">"id_label"</span><span class="p">:</span> <span class="nb">str</span><span class="p">(</span><span class="nb">id</span><span class="p">)</span> <span class="o">+</span> <span class="s2">"_"</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">label</span><span class="p">)}</span> <span class="c1"># Compute the hash values.</span> <span class="n">hashes</span> <span class="o">=</span> <span class="n">hash_func</span><span class="p">(</span><span class="n">vectors</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">random_vectors</span><span class="p">)</span> <span class="c1"># Add the hash values to the current table.</span> <span class="k">for</span> <span class="n">h</span> <span class="ow">in</span> <span class="n">hashes</span><span class="p">:</span> <span class="k">if</span> <span class="n">h</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">table</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">table</span><span class="p">[</span><span class="n">h</span><span class="p">]</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">entry</span><span class="p">)</span> <span class="k">else</span><span 
class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">table</span><span class="p">[</span><span class="n">h</span><span class="p">]</span> <span class="o">=</span> <span class="p">[</span><span class="n">entry</span><span class="p">]</span> <span class="k">def</span> <span class="nf">query</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">vectors</span><span class="p">):</span> <span class="c1"># Compute hash value for the query vector.</span> <span class="n">hashes</span> <span class="o">=</span> <span class="n">hash_func</span><span class="p">(</span><span class="n">vectors</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">random_vectors</span><span class="p">)</span> <span class="n">results</span> <span class="o">=</span> <span class="p">[]</span> <span class="c1"># Loop over the query hashes and determine if they exist in</span> <span class="c1"># the current table.</span> <span class="k">for</span> <span class="n">h</span> <span class="ow">in</span> <span class="n">hashes</span><span class="p">:</span> <span class="k">if</span> <span class="n">h</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">table</span><span class="p">:</span> <span class="n">results</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">table</span><span class="p">[</span><span class="n">h</span><span class="p">])</span> <span class="k">return</span> <span class="n">results</span> </code></pre></div> <p>The following <code>LSH</code> class wraps several <code>Table</code> instances so that we can work with multiple hash tables at once.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">LSH</span><span class="p">:</span> <span class="k">def</span> <span class="fm">__init__</span><span 
class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">hash_size</span><span class="p">,</span> <span class="n">dim</span><span class="p">,</span> <span class="n">num_tables</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_tables</span> <span class="o">=</span> <span class="n">num_tables</span> <span class="bp">self</span><span class="o">.</span><span class="n">tables</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_tables</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">tables</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">Table</span><span class="p">(</span><span class="n">hash_size</span><span class="p">,</span> <span class="n">dim</span><span class="p">))</span> <span class="k">def</span> <span class="nf">add</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="nb">id</span><span class="p">,</span> <span class="n">vectors</span><span class="p">,</span> <span class="n">label</span><span class="p">):</span> <span class="k">for</span> <span class="n">table</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">tables</span><span class="p">:</span> <span class="n">table</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="nb">id</span><span class="p">,</span> <span class="n">vectors</span><span class="p">,</span> <span class="n">label</span><span class="p">)</span> <span class="k">def</span> <span class="nf">query</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">vectors</span><span 
class="p">):</span> <span class="n">results</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">table</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">tables</span><span class="p">:</span> <span class="n">results</span><span class="o">.</span><span class="n">extend</span><span class="p">(</span><span class="n">table</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">vectors</span><span class="p">))</span> <span class="k">return</span> <span class="n">results</span> </code></pre></div> <p>Now we can encapsulate the logic for building and querying the master LSH table (a collection of many tables) inside a class. It has two methods:</p> <ul> <li><code>train()</code>: Responsible for building the final LSH table.</li> <li><code>query()</code>: Computes the number of matches for a query image and assigns a similarity score to each matched image.</li> </ul> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">BuildLSHTable</span><span class="p">:</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">prediction_model</span><span class="p">,</span> <span class="n">concrete_function</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">hash_size</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">2048</span><span class="p">,</span> <span class="n">num_tables</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">hash_size</span> <span class="o">=</span> <span class="n">hash_size</span> <span class="bp">self</span><span 
class="o">.</span><span class="n">dim</span> <span class="o">=</span> <span class="n">dim</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_tables</span> <span class="o">=</span> <span class="n">num_tables</span> <span class="bp">self</span><span class="o">.</span><span class="n">lsh</span> <span class="o">=</span> <span class="n">LSH</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">hash_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dim</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_tables</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">prediction_model</span> <span class="o">=</span> <span class="n">prediction_model</span> <span class="bp">self</span><span class="o">.</span><span class="n">concrete_function</span> <span class="o">=</span> <span class="n">concrete_function</span> <span class="k">def</span> <span class="nf">train</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">training_files</span><span class="p">):</span> <span class="k">for</span> <span class="nb">id</span><span class="p">,</span> <span class="n">training_file</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">training_files</span><span class="p">):</span> <span class="c1"># Unpack the data.</span> <span class="n">image</span><span class="p">,</span> <span class="n">label</span> <span class="o">=</span> <span class="n">training_file</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">image</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o"><</span> <span class="mi">4</span><span class="p">:</span> <span class="n">image</span> <span class="o">=</span> <span 
class="n">image</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="c1"># Compute embeddings and update the LSH tables.</span> <span class="c1"># More on `self.concrete_function()` later.</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">concrete_function</span><span class="p">:</span> <span class="n">features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prediction_model</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="n">image</span><span class="p">))[</span> <span class="s2">"normalization"</span> <span class="p">]</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="k">else</span><span class="p">:</span> <span class="n">features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prediction_model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">lsh</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="nb">id</span><span class="p">,</span> <span class="n">features</span><span class="p">,</span> <span class="n">label</span><span class="p">)</span> <span class="k">def</span> <span class="nf">query</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">image</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span> <span class="c1"># Compute the embeddings of the query image and fetch the results.</span> <span class="k">if</span> <span class="nb">len</span><span 
class="p">(</span><span class="n">image</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o"><</span> <span class="mi">4</span><span class="p">:</span> <span class="n">image</span> <span class="o">=</span> <span class="n">image</span><span class="p">[</span><span class="kc">None</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">concrete_function</span><span class="p">:</span> <span class="n">features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prediction_model</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="n">image</span><span class="p">))[</span> <span class="s2">"normalization"</span> <span class="p">]</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="k">else</span><span class="p">:</span> <span class="n">features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">prediction_model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="n">results</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">lsh</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">features</span><span class="p">)</span> <span class="k">if</span> <span class="n">verbose</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Matches:"</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">results</span><span class="p">))</span> <span class="c1"># Calculate Jaccard index to quantify the similarity.</span> 
<span class="n">counts</span> <span class="o">=</span> <span class="p">{}</span> <span class="k">for</span> <span class="n">r</span> <span class="ow">in</span> <span class="n">results</span><span class="p">:</span> <span class="k">if</span> <span class="n">r</span><span class="p">[</span><span class="s2">"id_label"</span><span class="p">]</span> <span class="ow">in</span> <span class="n">counts</span><span class="p">:</span> <span class="n">counts</span><span class="p">[</span><span class="n">r</span><span class="p">[</span><span class="s2">"id_label"</span><span class="p">]]</span> <span class="o">+=</span> <span class="mi">1</span> <span class="k">else</span><span class="p">:</span> <span class="n">counts</span><span class="p">[</span><span class="n">r</span><span class="p">[</span><span class="s2">"id_label"</span><span class="p">]]</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">counts</span><span class="p">:</span> <span class="n">counts</span><span class="p">[</span><span class="n">k</span><span class="p">]</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="n">counts</span><span class="p">[</span><span class="n">k</span><span class="p">])</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">dim</span> <span class="k">return</span> <span class="n">counts</span> </code></pre></div> <hr /> <h2 id="create-lsh-tables">Create LSH tables</h2> <p>With our helper utilities and classes implemented, we can now build our LSH table. 
Since we will be comparing the performance of the optimized and unoptimized embedding models, we first warm up the GPU so that the comparison is fair.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Utility to warm up the GPU.</span> <span class="k">def</span> <span class="nf">warmup</span><span class="p">():</span> <span class="n">dummy_sample</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100</span><span class="p">):</span> <span class="n">_</span> <span class="o">=</span> <span class="n">embedding_model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">dummy_sample</span><span class="p">)</span> </code></pre></div> <p>Now we can warm up the GPU and then build the master LSH table with <code>embedding_model</code>.</p> <div class="codehilite"><pre><span></span><code><span class="n">warmup</span><span class="p">()</span> <span class="n">training_files</span> <span class="o">=</span> <span class="nb">zip</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span> <span class="n">lsh_builder</span> <span class="o">=</span> <span class="n">BuildLSHTable</span><span class="p">(</span><span class="n">embedding_model</span><span class="p">)</span> <span class="n">lsh_builder</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="n">training_files</span><span class="p">)</span> </code></pre></div> <p>At the time of 
writing, the wall time was 54.1 seconds on a Tesla T4 GPU. This timing may vary depending on the GPU you are using.</p> <hr /> <h2 id="optimize-the-model-with-tensorrt">Optimize the model with TensorRT</h2> <p>For NVIDIA GPUs, the <a href="https://docs.nvidia.com/deeplearning/frameworks/tf-trt-user-guide/index.html">TensorRT framework</a> can be used to dramatically reduce inference latency by applying various model optimization techniques such as pruning, constant folding, and layer fusion. Here we will use the <a href="https://www.tensorflow.org/api_docs/python/tf/experimental/tensorrt"><code>tf.experimental.tensorrt</code></a> module to optimize our embedding model.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># First serialize the embedding model as a SavedModel.</span> <span class="n">embedding_model</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s2">"embedding_model"</span><span class="p">)</span> <span class="c1"># Initialize the conversion parameters.</span> <span class="n">params</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">experimental</span><span class="o">.</span><span class="n">tensorrt</span><span class="o">.</span><span class="n">ConversionParams</span><span class="p">(</span> <span class="n">precision_mode</span><span class="o">=</span><span class="s2">"FP16"</span><span class="p">,</span> <span class="n">maximum_cached_engines</span><span class="o">=</span><span class="mi">16</span> <span class="p">)</span> <span class="c1"># Run the conversion.</span> <span class="n">converter</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">experimental</span><span class="o">.</span><span class="n">tensorrt</span><span class="o">.</span><span class="n">Converter</span><span class="p">(</span> <span class="n">input_saved_model_dir</span><span class="o">=</span><span 
class="s2">"embedding_model"</span><span class="p">,</span> <span class="n">conversion_params</span><span class="o">=</span><span class="n">params</span> <span class="p">)</span> <span class="n">converter</span><span class="o">.</span><span class="n">convert</span><span class="p">()</span> <span class="n">converter</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s2">"tensorrt_embedding_model"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model. WARNING:tensorflow:Compiled the loaded model, but the compiled metrics have yet to be built. `model.compile_metrics` will be empty until you train or evaluate the model. INFO:tensorflow:Assets written to: embedding_model/assets INFO:tensorflow:Assets written to: embedding_model/assets INFO:tensorflow:Linked TensorRT version: (0, 0, 0) INFO:tensorflow:Linked TensorRT version: (0, 0, 0) INFO:tensorflow:Loaded TensorRT version: (0, 0, 0) INFO:tensorflow:Loaded TensorRT version: (0, 0, 0) INFO:tensorflow:Assets written to: tensorrt_embedding_model/assets INFO:tensorflow:Assets written to: tensorrt_embedding_model/assets </code></pre></div> </div> <p><strong>Notes on the parameters inside of <code>tf.experimental.tensorrt.ConversionParams()</code></strong>:</p> <ul> <li><code>precision_mode</code> defines the numerical precision of the operations in the to-be-converted model.</li> <li><code>maximum_cached_engines</code> specifies the maximum number of TRT engines that will be cached to handle dynamic operations (operations with unknown shapes).</li> </ul> <p>To learn more about the other options, refer to the <a href="https://www.tensorflow.org/api_docs/python/tf/experimental/tensorrt/ConversionParams">official documentation</a>. 
You can also explore the different quantization options provided by the <a href="https://www.tensorflow.org/api_docs/python/tf/experimental/tensorrt"><code>tf.experimental.tensorrt</code></a> module.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Load the converted model.</span> <span class="n">root</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">saved_model</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"tensorrt_embedding_model"</span><span class="p">)</span> <span class="n">trt_model_function</span> <span class="o">=</span> <span class="n">root</span><span class="o">.</span><span class="n">signatures</span><span class="p">[</span><span class="s2">"serving_default"</span><span class="p">]</span> </code></pre></div> <hr /> <h2 id="build-lsh-tables-with-optimized-model">Build LSH tables with optimized model</h2> <div class="codehilite"><pre><span></span><code><span class="n">warmup</span><span class="p">()</span> <span class="n">training_files</span> <span class="o">=</span> <span class="nb">zip</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span> <span class="n">lsh_builder_trt</span> <span class="o">=</span> <span class="n">BuildLSHTable</span><span class="p">(</span><span class="n">trt_model_function</span><span class="p">,</span> <span class="n">concrete_function</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">lsh_builder_trt</span><span class="o">.</span><span class="n">train</span><span class="p">(</span><span class="n">training_files</span><span class="p">)</span> </code></pre></div> <p>Notice the difference in the wall time which is <strong>13.1 seconds</strong>. 
Earlier, with the unoptimized model, it was <strong>54.1 seconds</strong>.</p> <p>We can take a closer look at one of the hash tables to get an idea of how it is represented.</p> <div class="codehilite"><pre><span></span><code><span class="n">idx</span> <span class="o">=</span> <span class="mi">0</span> <span class="k">for</span> <span class="nb">hash</span><span class="p">,</span> <span class="n">entry</span> <span class="ow">in</span> <span class="n">lsh_builder_trt</span><span class="o">.</span><span class="n">lsh</span><span class="o">.</span><span class="n">tables</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">table</span><span class="o">.</span><span class="n">items</span><span class="p">():</span> <span class="k">if</span> <span class="n">idx</span> <span class="o">==</span> <span class="mi">5</span><span class="p">:</span> <span class="k">break</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">entry</span><span class="p">)</span> <span class="o"><</span> <span class="mi">5</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="nb">hash</span><span class="p">,</span> <span class="n">entry</span><span class="p">)</span> <span class="n">idx</span> <span class="o">+=</span> <span class="mi">1</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>145 [{'id_label': '3_4'}, {'id_label': '727_3'}] 5 [{'id_label': '12_4'}] 128 [{'id_label': '30_2'}, {'id_label': '480_2'}] 208 [{'id_label': '34_2'}, {'id_label': '132_2'}, {'id_label': '984_2'}] 188 [{'id_label': '42_0'}, {'id_label': '135_3'}, {'id_label': '436_3'}, {'id_label': '670_3'}] </code></pre></div> </div> <hr /> <h2 id="visualize-results-on-validation-images">Visualize results on validation images</h2> <p>In this section, we will first write a couple of utility functions to visualize 
the similar image parsing process. Then we will benchmark the query performance of the models with and without optimization.</p> <p>First, we take 100 images from the validation set for testing purposes.</p> <div class="codehilite"><pre><span></span><code><span class="n">validation_images</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">validation_labels</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">image</span><span class="p">,</span> <span class="n">label</span> <span class="ow">in</span> <span class="n">validation_ds</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="mi">100</span><span class="p">):</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="p">(</span><span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">))</span> <span class="n">validation_images</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">image</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span> <span class="n">validation_labels</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">label</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span> <span class="n">validation_images</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">validation_images</span><span class="p">)</span> <span class="n">validation_labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span 
class="n">validation_labels</span><span class="p">)</span> <span class="n">validation_images</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">validation_labels</span><span class="o">.</span><span class="n">shape</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>((100, 224, 224, 3), (100,)) </code></pre></div> </div> <p>Now we write our visualization utilities.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">plot_images</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="n">columns</span> <span class="o">=</span> <span class="mi">5</span> <span class="k">for</span> <span class="p">(</span><span class="n">i</span><span class="p">,</span> <span class="n">image</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">images</span><span class="p">):</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">images</span><span class="p">)</span> <span class="o">//</span> <span class="n">columns</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="n">columns</span><span class="p">,</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="k">if</span> <span class="n">i</span> <span 
class="o">==</span> <span class="mi">0</span><span class="p">:</span> <span class="n">ax</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">"Query Image</span><span class="se">\n</span><span class="s2">"</span> <span class="o">+</span> <span class="s2">"Label: </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">labels</span><span class="p">[</span><span class="n">i</span><span class="p">]))</span> <span class="k">else</span><span class="p">:</span> <span class="n">ax</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">"Similar Image # "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="o">+</span> <span class="s2">"</span><span class="se">\n</span><span class="s2">Label: </span><span class="si">{}</span><span class="s2">"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">labels</span><span class="p">[</span><span class="n">i</span><span class="p">]))</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"int"</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="k">def</span> <span class="nf">visualize_lsh</span><span class="p">(</span><span class="n">lsh_class</span><span class="p">):</span> <span class="n">idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span 
class="nb">len</span><span class="p">(</span><span class="n">validation_images</span><span class="p">))</span> <span class="n">image</span> <span class="o">=</span> <span class="n">validation_images</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="n">label</span> <span class="o">=</span> <span class="n">validation_labels</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span> <span class="n">results</span> <span class="o">=</span> <span class="n">lsh_class</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="n">candidates</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">labels</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">overlaps</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">r</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">sorted</span><span class="p">(</span><span class="n">results</span><span class="p">,</span> <span class="n">key</span><span class="o">=</span><span class="n">results</span><span class="o">.</span><span class="n">get</span><span class="p">,</span> <span class="n">reverse</span><span class="o">=</span><span class="kc">True</span><span class="p">)):</span> <span class="k">if</span> <span class="n">idx</span> <span class="o">==</span> <span class="mi">4</span><span class="p">:</span> <span class="k">break</span> <span class="n">image_id</span><span class="p">,</span> <span class="n">candidate_label</span> <span class="o">=</span> <span class="n">r</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"_"</span><span class="p">)[</span><span class="mi">0</span><span class="p">],</span> <span class="n">r</span><span class="o">.</span><span 
class="n">split</span><span class="p">(</span><span class="s2">"_"</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span> <span class="n">candidates</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">images</span><span class="p">[</span><span class="nb">int</span><span class="p">(</span><span class="n">image_id</span><span class="p">)])</span> <span class="n">labels</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">candidate_label</span><span class="p">)</span> <span class="n">overlaps</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">results</span><span class="p">[</span><span class="n">r</span><span class="p">])</span> <span class="n">candidates</span><span class="o">.</span><span class="n">insert</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">image</span><span class="p">)</span> <span class="n">labels</span><span class="o">.</span><span class="n">insert</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">label</span><span class="p">)</span> <span class="n">plot_images</span><span class="p">(</span><span class="n">candidates</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span> </code></pre></div> <h3 id="nontrt-model">Non-TRT model</h3> <div class="codehilite"><pre><span></span><code><span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">5</span><span class="p">):</span> <span class="n">visualize_lsh</span><span class="p">(</span><span class="n">lsh_builder</span><span class="p">)</span> <span class="n">visualize_lsh</span><span class="p">(</span><span class="n">lsh_builder</span><span class="p">)</span> </code></pre></div> <div 
class="codehilite"><pre><span></span><code>Matches: 507
Matches: 554
Matches: 438
Matches: 370
Matches: 407
Matches: 306
</code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/near_dup_search/near_dup_search_41_1.png" /></p> <p><img alt="png" src="/img/examples/vision/near_dup_search/near_dup_search_41_2.png" /></p> <p><img alt="png" src="/img/examples/vision/near_dup_search/near_dup_search_41_3.png" /></p> <p><img alt="png" src="/img/examples/vision/near_dup_search/near_dup_search_41_4.png" /></p> <p><img alt="png" src="/img/examples/vision/near_dup_search/near_dup_search_41_5.png" /></p> <p><img alt="png" src="/img/examples/vision/near_dup_search/near_dup_search_41_6.png" /></p> <h3 id="trt-model">TRT model</h3> <div class="codehilite"><pre><span></span><code><span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">5</span><span class="p">):</span>
    <span class="n">visualize_lsh</span><span class="p">(</span><span class="n">lsh_builder_trt</span><span class="p">)</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Matches: 458
Matches: 181
Matches: 280
Matches: 280
Matches: 503
</code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/near_dup_search/near_dup_search_43_1.png" /></p> <p><img alt="png" src="/img/examples/vision/near_dup_search/near_dup_search_43_2.png" /></p> <p><img alt="png" src="/img/examples/vision/near_dup_search/near_dup_search_43_3.png" /></p> <p><img alt="png" src="/img/examples/vision/near_dup_search/near_dup_search_43_4.png" /></p> <p><img alt="png" src="/img/examples/vision/near_dup_search/near_dup_search_43_5.png" /></p> <p>As you may have noticed, a few of the retrieved candidates are incorrect. This can be mitigated in a few ways:</p> <ul> <li>Better models for generating the initial embeddings, especially for noisy samples.
We can use techniques like <a href="https://arxiv.org/abs/1801.07698">ArcFace</a>, <a href="https://arxiv.org/abs/2004.11362">Supervised Contrastive Learning</a>, etc., which implicitly encourage representations that are better suited for retrieval.</li> <li>Tuning the trade-off between the number of hash tables and the reduced (hash) dimensionality is crucial for reaching the recall your application requires. More tables raise the chance that a true neighbor collides with the query in at least one table, while more hash bits per table make each bucket more selective.</li> </ul> <hr /> <h2 id="benchmarking-query-performance">Benchmarking query performance</h2> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">benchmark</span><span class="p">(</span><span class="n">lsh_class</span><span class="p">):</span>
    <span class="c1"># Warm up first so one-time initialization is not included in the timing.</span>
    <span class="n">warmup</span><span class="p">()</span>

    <span class="n">start_time</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span>
    <span class="c1"># Time 1,000 queries that use a dummy all-ones image.</span>
    <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">1000</span><span class="p">):</span>
        <span class="n">image</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">((</span><span class="mi">1</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">)</span>
        <span class="n">_</span> <span class="o">=</span> <span class="n">lsh_class</span><span class="o">.</span><span class="n">query</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
    <span class="c1"># Total wall-clock time, in seconds, for all 1,000 queries.</span>
    <span class="n">elapsed</span> <span class="o">=</span> <span class="n">time</span><span class="o">.</span><span class="n">time</span><span class="p">()</span> <span class="o">-</span> <span class="n">start_time</span>
    <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Time taken: </span><span class="si">{</span><span class="n">elapsed</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>


<span class="n">benchmark</span><span class="p">(</span><span class="n">lsh_builder</span><span class="p">)</span>

<span class="n">benchmark</span><span class="p">(</span><span class="n">lsh_builder_trt</span><span class="p">)</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Time taken: 54.359
Time taken: 13.963
</code></pre></div> </div> <p>The difference in query performance is stark. The TensorRT-optimized model answers the 1,000 queries in about 14 seconds, compared with about 54 seconds for the non-optimized model, roughly a 3.9x speedup.</p> <hr /> <h2 id="final-remarks">Final remarks</h2> <p>In this example, we used NVIDIA's TensorRT framework to optimize our model. It is best suited for GPU-based inference servers.
There are other such frameworks that cater to different hardware platforms:</p> <ul> <li><a href="https://www.tensorflow.org/lite">TensorFlow Lite</a> for mobile and edge devices.</li> <li><a href="https://onnx.ai/">ONNX</a> for commodity CPU-based servers.</li> <li><a href="https://tvm.apache.org/">Apache TVM</a>, a compiler for machine learning models that covers various platforms.</li> </ul> <p>Here are a few resources you might want to check out to learn more about applications based on vector similarity search in general:</p> <ul> <li><a href="http://ann-benchmarks.com/">ANN Benchmarks</a></li> <li><a href="https://arxiv.org/abs/1908.10396">Accelerating Large-Scale Inference with Anisotropic Vector Quantization (ScaNN)</a></li> <li><a href="https://arxiv.org/abs/1806.03198">Spreading vectors for similarity search</a></li> <li><a href="https://cloud.google.com/architecture/building-real-time-embeddings-similarity-matching-system">Building a real-time embeddings similarity matching system</a></li> </ul> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#nearduplicate-image-search'>Near-duplicate image search</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#imports'>Imports</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#load-the-dataset-and-create-a-training-set-of-1000-images'>Load the dataset and create a training set of 1,000 images</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#load-a-pretrained-model'>Load a pre-trained model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#create-an-embedding-model'>Create an embedding model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#hashing-utilities'>Hashing utilities</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#query-utilities'>Query utilities</a> </div> <div class='k-outline-depth-2'> ◆ <a
href='#create-lsh-tables'>Create LSH tables</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#optimize-the-model-with-tensorrt'>Optimize the model with TensorRT</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#build-lsh-tables-with-optimized-model'>Build LSH tables with optimized model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#visualize-results-on-validation-images'>Visualize results on validation images</a> </div> <div class='k-outline-depth-3'> <a href='#nontrt-model'>Non-TRT model</a> </div> <div class='k-outline-depth-3'> <a href='#trt-model'>TRT model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#benchmarking-query-performance'>Benchmarking query performance</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#final-remarks'>Final remarks</a> </div> </div> </div> </div> </div> </body> </html>