CINXE.COM
Segment Anything Model with 🤗Transformers
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/sam/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Segment Anything Model with 🤗Transformers"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Segment Anything Model with 🤗Transformers"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Segment Anything Model with 🤗Transformers</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a class="nav-sublink2" href="/examples/vision/mobilevit/">A mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using Composable Fully-Convolutional Networks</a> <a class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved Robustness</a> <a class="nav-sublink2" href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/bit/">Image Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2 active" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a class="nav-sublink2" href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/vision/'>Computer Vision</a> / Segment Anything Model with 🤗Transformers </div> <div class='k-content'> <h1 id="segment-anything-model-with-🤗transformers">Segment Anything Model with 🤗Transformers</h1> <p><strong>Authors:</strong> <a href="https://twitter.com/mervenoyann">Merve Noyan</a> & <a href="https://twitter.com/RisingSayak">Sayak Paul</a><br> <strong>Date created:</strong> 2023/07/11<br> <strong>Last modified:</strong> 2023/07/11<br> <strong>Description:</strong> Fine-tuning Segment Anything Model using Keras and 🤗 Transformers.</p> <div class='example_version_banner keras_2'>ⓘ This example uses Keras 2</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/sam.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/sam.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>Large language models (LLMs) make it easy for the end users to apply them to various applications through "prompting". For example if we wanted an LLM to predict the sentiment of the following sentence – "That movie was amazing, I thoroughly enjoyed it" – we'd do prompt the LLM with something like:</p> <blockquote> <p>What's the sentiment of the following sentence: "That movie was amazing, I thoroughly enjoyed it"?</p> </blockquote> <p>In return, the LLM would return sentiment token.</p> <p>But when it comes to visual recognition tasks, how can we engineer "visual" cues to prompt foundation vision models? For example, we could have an input image and prompt the model with bounding box on that image and ask it to perform segmentation. The bounding box would serve as our visual prompt here.</p> <p>In the <a href="https://segment-anything.com/">Segment Anything Model</a> (dubbed as SAM), researchers from Meta extended the space of language prompting to visual prompting. SAM is capable of performing zero-shot segmentation with a prompt input, inspired by large language models. The prompt here can be a set of foreground/background points, free text, a box or a mask. There are many downstream segmentation tasks, including semantic segmentation and edge detection. The goal of SAM is to enable all of these downstream segmentation tasks through prompting.</p> <p>In this example, we'll learn how to use the SAM model from 🤗 Transformers for performing inference and fine-tuning.</p> <hr /> <h2 id="installation">Installation</h2> <div class="codehilite"><pre><span></span><code><span class="err">!!</span><span class="n">pip</span> <span class="n">install</span> <span class="o">-</span><span class="n">q</span> <span class="n">git</span><span class="o">+</span><span class="n">https</span><span class="p">:</span><span class="o">//</span><span class="n">github</span><span class="o">.</span><span class="n">com</span><span class="o">/</span><span class="n">huggingface</span><span class="o">/</span><span class="n">transformers</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>[] </code></pre></div> </div> <p>Let's import everything we need for this example.</p> <div class="codehilite"><pre><span></span><code><span class="kn">from</span> <span class="nn">tensorflow</span> <span class="kn">import</span> <span class="n">keras</span> <span class="kn">from</span> <span class="nn">transformers</span> <span class="kn">import</span> <span class="n">TFSamModel</span><span class="p">,</span> <span class="n">SamProcessor</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span> <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span> <span class="kn">from</span> <span class="nn">tensorflow.python.ops.numpy_ops</span> <span class="kn">import</span> <span class="n">np_config</span> <span class="kn">from</span> <span class="nn">PIL</span> <span class="kn">import</span> <span class="n">Image</span> <span class="kn">import</span> <span class="nn">requests</span> <span class="kn">import</span> <span class="nn">glob</span> <span class="kn">import</span> <span class="nn">os</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>/Users/mervenoyan/miniforge3/envs/py310/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html from .autonotebook import tqdm as notebook_tqdm </code></pre></div> </div> <hr /> <h2 id="sam-in-a-few-words">SAM in a few words</h2> <p>SAM has the following components:</p> <table> <thead> <tr> <th style="text-align: center;"><img alt="" src="https://imgur.com/oLfdwuB.png" /></th> </tr> </thead> <tbody> <tr> <td style="text-align: center;">Image taken from the official <a href="https://ai.facebook.com/blog/segment-anything-foundation-model-image-segmentation/">SAM blog post</a></td> </tr> </tbody> </table> <p>The image encoder is responsible for computing image embeddings. When interacting with SAM, we compute the image embedding one time (as the image encoder is heavy) and then reuse it with different prompts mentioned above (points, bounding boxes, masks).</p> <p>Points and boxes (so-called sparse prompts) go through a lightweight prompt encoder, while masks (dense prompts) go through a convolutional layer. We couple the image embedding extracted from the image encoder and the prompt embedding and both go to a lightweight mask decoder. The decoder is responsible for predicting the mask.</p> <table> <thead> <tr> <th style="text-align: center;"><img alt="" src="https://i.imgur.com/QQ9Ts5T.png" /></th> </tr> </thead> <tbody> <tr> <td style="text-align: center;">Figure taken from the <a href="https://arxiv.org/abs/2304.02643">SAM paper</a></td> </tr> </tbody> </table> <p>SAM was pre-trained to predict a <em>valid</em> mask for any acceptable prompt. This requirement allows SAM to output a valid mask even when the prompt is ambiguous to understand – this makes SAM ambiguity-aware. Moreover, SAM predicts multiple masks for a single prompt.</p> <p>We highly encourage you to check out the <a href="https://arxiv.org/abs/2304.02643">SAM paper</a> and the <a href="https://ai.facebook.com/blog/segment-anything-foundation-model-image-segmentation/">blog post</a> to learn more about the additional details of SAM and the dataset used to pre-trained it.</p> <hr /> <h2 id="running-inference-with-sam">Running inference with SAM</h2> <p>There are three checkpoints for SAM:</p> <ul> <li><a href="https://huggingface.co/facebook/sam-vit-base">sam-vit-base</a></li> <li><a href="https://huggingface.co/facebook/sam-vit-large">sam-vit-large</a></li> <li><a href="https://huggingface.co/facebook/sam-vit-huge">sam-vit-huge</a>.</li> </ul> <p>We load <code>sam-vit-base</code> in <a href="https://huggingface.co/docs/transformers/main/model_doc/sam#transformers.TFSamModel"><code>TFSamModel</code></a>. We also need <code>SamProcessor</code>for the associated checkpoint.</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="n">TFSamModel</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s2">"facebook/sam-vit-base"</span><span class="p">)</span> <span class="n">processor</span> <span class="o">=</span> <span class="n">SamProcessor</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s2">"facebook/sam-vit-base"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>All model checkpoint layers were used when initializing TFSamModel. </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>All the layers of TFSamModel were initialized from the model checkpoint at facebook/sam-vit-base. If your task is similar to the task the model of the checkpoint was trained on, you can already use TFSamModel for predictions without further training. </code></pre></div> </div> <p>Next, we write some utility functions for visualization. Most of these functions are taken from <a href="https://github.com/huggingface/notebooks/blob/main/examples/segment_anything.ipynb">this notebook</a>.</p> <div class="codehilite"><pre><span></span><code><span class="n">np_config</span><span class="o">.</span><span class="n">enable_numpy_behavior</span><span class="p">()</span> <span class="k">def</span> <span class="nf">show_mask</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">ax</span><span class="p">,</span> <span class="n">random_color</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span> <span class="k">if</span> <span class="n">random_color</span><span class="p">:</span> <span class="n">color</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.6</span><span class="p">])],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">color</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">30</span> <span class="o">/</span> <span class="mi">255</span><span class="p">,</span> <span class="mi">144</span> <span class="o">/</span> <span class="mi">255</span><span class="p">,</span> <span class="mi">255</span> <span class="o">/</span> <span class="mi">255</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">])</span> <span class="n">h</span><span class="p">,</span> <span class="n">w</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">:]</span> <span class="n">mask_image</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">color</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="n">ax</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">mask_image</span><span class="p">)</span> <span class="k">def</span> <span class="nf">show_box</span><span class="p">(</span><span class="n">box</span><span class="p">,</span> <span class="n">ax</span><span class="p">):</span> <span class="n">x0</span><span class="p">,</span> <span class="n">y0</span> <span class="o">=</span> <span class="n">box</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">box</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="n">w</span><span class="p">,</span> <span class="n">h</span> <span class="o">=</span> <span class="n">box</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">-</span> <span class="n">box</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">box</span><span class="p">[</span><span class="mi">3</span><span class="p">]</span> <span class="o">-</span> <span class="n">box</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="n">ax</span><span class="o">.</span><span class="n">add_patch</span><span class="p">(</span> <span class="n">plt</span><span class="o">.</span><span class="n">Rectangle</span><span class="p">((</span><span class="n">x0</span><span class="p">,</span> <span class="n">y0</span><span class="p">),</span> <span class="n">w</span><span class="p">,</span> <span class="n">h</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s2">"green"</span><span class="p">,</span> <span class="n">facecolor</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="n">lw</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">show_boxes_on_image</span><span class="p">(</span><span class="n">raw_image</span><span class="p">,</span> <span class="n">boxes</span><span class="p">):</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">raw_image</span><span class="p">)</span> <span class="k">for</span> <span class="n">box</span> <span class="ow">in</span> <span class="n">boxes</span><span class="p">:</span> <span class="n">show_box</span><span class="p">(</span><span class="n">box</span><span class="p">,</span> <span class="n">plt</span><span class="o">.</span><span class="n">gca</span><span class="p">())</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"on"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="k">def</span> <span class="nf">show_points_on_image</span><span class="p">(</span><span class="n">raw_image</span><span class="p">,</span> <span class="n">input_points</span><span class="p">,</span> <span class="n">input_labels</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">raw_image</span><span class="p">)</span> <span class="n">input_points</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">input_points</span><span class="p">)</span> <span class="k">if</span> <span class="n">input_labels</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">input_points</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span> <span class="k">else</span><span class="p">:</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">input_labels</span><span class="p">)</span> <span class="n">show_points</span><span class="p">(</span><span class="n">input_points</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span> <span class="n">plt</span><span class="o">.</span><span class="n">gca</span><span class="p">())</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"on"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="k">def</span> <span class="nf">show_points_and_boxes_on_image</span><span class="p">(</span><span class="n">raw_image</span><span class="p">,</span> <span class="n">boxes</span><span class="p">,</span> <span class="n">input_points</span><span class="p">,</span> <span class="n">input_labels</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">raw_image</span><span class="p">)</span> <span class="n">input_points</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">input_points</span><span class="p">)</span> <span class="k">if</span> <span class="n">input_labels</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">input_points</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span> <span class="k">else</span><span class="p">:</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">input_labels</span><span class="p">)</span> <span class="n">show_points</span><span class="p">(</span><span class="n">input_points</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span> <span class="n">plt</span><span class="o">.</span><span class="n">gca</span><span class="p">())</span> <span class="k">for</span> <span class="n">box</span> <span class="ow">in</span> <span class="n">boxes</span><span class="p">:</span> <span class="n">show_box</span><span class="p">(</span><span class="n">box</span><span class="p">,</span> <span class="n">plt</span><span class="o">.</span><span class="n">gca</span><span class="p">())</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"on"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="k">def</span> <span class="nf">show_points_and_boxes_on_image</span><span class="p">(</span><span class="n">raw_image</span><span class="p">,</span> <span class="n">boxes</span><span class="p">,</span> <span class="n">input_points</span><span class="p">,</span> <span class="n">input_labels</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">raw_image</span><span class="p">)</span> <span class="n">input_points</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">input_points</span><span class="p">)</span> <span class="k">if</span> <span class="n">input_labels</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">input_points</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">])</span> <span class="k">else</span><span class="p">:</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">input_labels</span><span class="p">)</span> <span class="n">show_points</span><span class="p">(</span><span class="n">input_points</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span> <span class="n">plt</span><span class="o">.</span><span class="n">gca</span><span class="p">())</span> <span class="k">for</span> <span class="n">box</span> <span class="ow">in</span> <span class="n">boxes</span><span class="p">:</span> <span class="n">show_box</span><span class="p">(</span><span class="n">box</span><span class="p">,</span> <span class="n">plt</span><span class="o">.</span><span class="n">gca</span><span class="p">())</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"on"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="k">def</span> <span class="nf">show_points</span><span class="p">(</span><span class="n">coords</span><span class="p">,</span> <span class="n">labels</span><span class="p">,</span> <span class="n">ax</span><span class="p">,</span> <span class="n">marker_size</span><span class="o">=</span><span class="mi">375</span><span class="p">):</span> <span class="n">pos_points</span> <span class="o">=</span> <span class="n">coords</span><span class="p">[</span><span class="n">labels</span> <span class="o">==</span> <span class="mi">1</span><span class="p">]</span> <span class="n">neg_points</span> <span class="o">=</span> <span class="n">coords</span><span class="p">[</span><span class="n">labels</span> <span class="o">==</span> <span class="mi">0</span><span class="p">]</span> <span class="n">ax</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span> <span class="n">pos_points</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">pos_points</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">color</span><span class="o">=</span><span class="s2">"green"</span><span class="p">,</span> <span class="n">marker</span><span class="o">=</span><span class="s2">"*"</span><span class="p">,</span> <span class="n">s</span><span class="o">=</span><span class="n">marker_size</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s2">"white"</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mf">1.25</span><span class="p">,</span> <span class="p">)</span> <span class="n">ax</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span> <span class="n">neg_points</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="n">neg_points</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="n">color</span><span class="o">=</span><span class="s2">"red"</span><span class="p">,</span> <span class="n">marker</span><span class="o">=</span><span class="s2">"*"</span><span class="p">,</span> <span class="n">s</span><span class="o">=</span><span class="n">marker_size</span><span class="p">,</span> <span class="n">edgecolor</span><span class="o">=</span><span class="s2">"white"</span><span class="p">,</span> <span class="n">linewidth</span><span class="o">=</span><span class="mf">1.25</span><span class="p">,</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">show_masks_on_image</span><span class="p">(</span><span class="n">raw_image</span><span class="p">,</span> <span class="n">masks</span><span class="p">,</span> <span class="n">scores</span><span class="p">):</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">masks</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">==</span> <span class="mi">4</span><span class="p">:</span> <span class="n">final_masks</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">masks</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="k">if</span> <span class="n">scores</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span> <span class="n">final_scores</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">scores</span><span class="p">)</span> <span class="n">nb_predictions</span> <span class="o">=</span> <span class="n">scores</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">nb_predictions</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">15</span><span class="p">,</span> <span class="mi">15</span><span class="p">))</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">score</span><span class="p">)</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">final_masks</span><span class="p">,</span> <span class="n">final_scores</span><span class="p">)):</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">stop_gradient</span><span class="p">(</span><span class="n">mask</span><span class="p">)</span> <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">raw_image</span><span class="p">))</span> <span class="n">show_mask</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">title</span><span class="o">.</span><span class="n">set_text</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Mask </span><span class="si">{</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s2">, Score: </span><span class="si">{</span><span class="n">score</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span><span class="si">:</span><span class="s2">.3f</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="n">axes</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> </code></pre></div> <p>We will segment a car image using a point prompt. Make sure to set <code>return_tensors</code> to <code>tf</code> when calling the processor.</p> <p>Let's load an image of a car and segment it.</p> <div class="codehilite"><pre><span></span><code><span class="n">img_url</span> <span class="o">=</span> <span class="s2">"https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"</span> <span class="n">raw_image</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">requests</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">img_url</span><span class="p">,</span> <span class="n">stream</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span><span class="o">.</span><span class="n">raw</span><span class="p">)</span><span class="o">.</span><span class="n">convert</span><span class="p">(</span><span class="s2">"RGB"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">raw_image</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/sam/sam_14_0.png" /></p> <p>Let's now define a set of points we will use as the prompt.</p> <div class="codehilite"><pre><span></span><code><span class="n">input_points</span> <span class="o">=</span> <span class="p">[[[</span><span class="mi">450</span><span class="p">,</span> <span class="mi">600</span><span class="p">]]]</span> <span class="c1"># Visualize a single point.</span> <span class="n">show_points_on_image</span><span class="p">(</span><span class="n">raw_image</span><span class="p">,</span> <span class="n">input_points</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/sam/sam_16_0.png" /></p> <p>And segment:</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Preprocess the input image.</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">processor</span><span class="p">(</span><span class="n">raw_image</span><span class="p">,</span> <span class="n">input_points</span><span class="o">=</span><span class="n">input_points</span><span class="p">,</span> <span class="n">return_tensors</span><span class="o">=</span><span class="s2">"tf"</span><span class="p">)</span> <span class="c1"># Predict for segmentation with the prompt.</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="o">**</span><span class="n">inputs</span><span class="p">)</span> </code></pre></div> <p><code>outputs</code> has got two attributes of our interest:</p> <ul> <li><code>outputs.pred_masks</code>: which denotes the predicted masks.</li> <li><code>outputs.iou_scores</code>: which denotes the IoU scores associated with the masks.</li> </ul> <p>Let's post-process the masks and visualize them with their IoU scores:</p> <div class="codehilite"><pre><span></span><code><span class="n">masks</span> <span class="o">=</span> <span class="n">processor</span><span class="o">.</span><span class="n">image_processor</span><span class="o">.</span><span class="n">post_process_masks</span><span class="p">(</span> <span class="n">outputs</span><span class="o">.</span><span class="n">pred_masks</span><span class="p">,</span> <span class="n">inputs</span><span class="p">[</span><span class="s2">"original_sizes"</span><span class="p">],</span> <span class="n">inputs</span><span class="p">[</span><span class="s2">"reshaped_input_sizes"</span><span class="p">],</span> <span class="n">return_tensors</span><span class="o">=</span><span class="s2">"tf"</span><span class="p">,</span> <span class="p">)</span> <span class="n">show_masks_on_image</span><span class="p">(</span><span class="n">raw_image</span><span class="p">,</span> <span class="n">masks</span><span class="p">,</span> <span class="n">outputs</span><span class="o">.</span><span class="n">iou_scores</span><span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/sam/sam_21_0.png" /></p> <p>And there we go!</p> <p>As can be noticed, all the masks are <em>valid</em> masks for the point prompt we provided.</p> <p>SAM is flexible enough to support different visual prompts and we encourage you to check out <a href="https://github.com/huggingface/notebooks/blob/main/examples/segment_anything.ipynb">this notebook</a> to know more about them!</p> <hr /> <h2 id="finetuning">Fine-tuning</h2> <p>We'll use <a href="https://huggingface.co/datasets/nielsr/breast-cancer">this dataset</a> consisting of breast cancer scans. In the medical imaging domain, being able to segment the cells containing malignancy is an important task.</p> <h3 id="data-preparation">Data preparation</h3> <p>Let's first get the dataset.</p> <div class="codehilite"><pre><span></span><code><span class="n">remote_path</span> <span class="o">=</span> <span class="s2">"https://hf.co/datasets/sayakpaul/sample-datasets/resolve/main/breast-cancer-dataset.tar.gz"</span> <span class="n">dataset_path</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">get_file</span><span class="p">(</span> <span class="s2">"breast-cancer-dataset.tar.gz"</span><span class="p">,</span> <span class="n">remote_path</span><span class="p">,</span> <span class="n">untar</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> </code></pre></div> <p>Let's now visualize a sample from the dataset.</p> <p><em>(The <code>show_mask()</code> utility is taken from <a href="https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SAM/Fine_tune_SAM_(segment_anything)_on_a_custom_dataset.ipynb">this notebook</a>)</em></p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">show_mask</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">ax</span><span class="p">,</span> <span class="n">random_color</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span> <span class="k">if</span> <span class="n">random_color</span><span class="p">:</span> <span class="n">color</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">(</span><span class="mi">3</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mf">0.6</span><span class="p">])],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">color</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">30</span> <span class="o">/</span> <span class="mi">255</span><span class="p">,</span> <span class="mi">144</span> <span class="o">/</span> <span class="mi">255</span><span class="p">,</span> <span class="mi">255</span> <span class="o">/</span> <span class="mi">255</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">])</span> <span class="n">h</span><span class="p">,</span> <span class="n">w</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">:]</span> <span class="n">mask_image</span> <span class="o">=</span> <span class="n">mask</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="n">w</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="o">*</span> <span class="n">color</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="n">ax</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">mask_image</span><span class="p">)</span> <span class="c1"># Load all the image and label paths.</span> <span class="n">image_paths</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">glob</span><span class="o">.</span><span class="n">glob</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">dataset_path</span><span class="p">,</span> <span class="s2">"images/*.png"</span><span class="p">)))</span> <span class="n">label_paths</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">glob</span><span class="o">.</span><span class="n">glob</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">dataset_path</span><span class="p">,</span> <span class="s2">"labels/*.png"</span><span class="p">)))</span> <span class="c1"># Load the image and label.</span> <span class="n">idx</span> <span class="o">=</span> <span class="mi">15</span> <span class="n">image</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">image_paths</span><span class="p">[</span><span class="n">idx</span><span class="p">])</span> <span class="n">label</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">label_paths</span><span class="p">[</span><span class="n">idx</span><span class="p">])</span> <span class="n">image</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="n">ground_truth_seg</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">label</span><span class="p">)</span> <span class="c1"># Display.</span> <span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">()</span> <span class="n">axes</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="n">show_mask</span><span class="p">(</span><span class="n">ground_truth_seg</span><span class="p">,</span> <span class="n">axes</span><span class="p">)</span> <span class="n">axes</span><span class="o">.</span><span class="n">title</span><span class="o">.</span><span class="n">set_text</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Ground truth mask"</span><span class="p">)</span> <span class="n">axes</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="n">tf</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">ground_truth_seg</span><span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/sam/sam_26_0.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code><tf.Tensor: shape=(2,), dtype=int32, numpy=array([256, 256], dtype=int32)> </code></pre></div> </div> <h3 id="tfdatadataset">Preparing <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><code>tf.data.Dataset</code></a></h3> <p>We now write a generator class to prepare the images and the segmentation masks using the <code>processor</code> utilized above. We will leverage this generator class to create a <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><code>tf.data.Dataset</code></a> object for our training set by using <code>tf.data.Dataset.from_generator()</code>. Utilities of this class have been adapted from <a href="https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SAM/Fine_tune_SAM_(segment_anything)_on_a_custom_dataset.ipynb">this notebook</a>.</p> <p>The generator is responsible for yielding the preprocessed images and the segmentation masks, and some other metadata needed by the SAM model.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">Generator</span><span class="p">:</span> <span class="w"> </span><span class="sd">"""Generator class for processing the images and the masks for SAM fine-tuning."""</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">dataset_path</span><span class="p">,</span> <span class="n">processor</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_path</span> <span class="o">=</span> <span class="n">dataset_path</span> <span class="bp">self</span><span class="o">.</span><span class="n">image_paths</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span> <span class="n">glob</span><span class="o">.</span><span class="n">glob</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_path</span><span class="p">,</span> <span class="s2">"images/*.png"</span><span class="p">))</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">label_paths</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span> <span class="n">glob</span><span class="o">.</span><span class="n">glob</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">dataset_path</span><span class="p">,</span> <span class="s2">"labels/*.png"</span><span class="p">))</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">processor</span> <span class="o">=</span> <span class="n">processor</span> <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="k">for</span> <span class="n">image_path</span><span class="p">,</span> <span class="n">label_path</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">image_paths</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">label_paths</span><span class="p">):</span> <span class="n">image</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">image_path</span><span class="p">))</span> <span class="n">ground_truth_mask</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">label_path</span><span class="p">))</span> <span class="c1"># get bounding box prompt</span> <span class="n">prompt</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">get_bounding_box</span><span class="p">(</span><span class="n">ground_truth_mask</span><span class="p">)</span> <span class="c1"># prepare image and prompt for the model</span> <span class="n">inputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">processor</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">input_boxes</span><span class="o">=</span><span class="p">[[</span><span class="n">prompt</span><span class="p">]],</span> <span class="n">return_tensors</span><span class="o">=</span><span class="s2">"np"</span><span class="p">)</span> <span class="c1"># remove batch dimension which the processor adds by default</span> <span class="n">inputs</span> <span class="o">=</span> <span class="p">{</span><span class="n">k</span><span class="p">:</span> <span class="n">v</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="mi">0</span><span class="p">)</span> <span class="k">for</span> <span class="n">k</span><span class="p">,</span> <span class="n">v</span> <span class="ow">in</span> <span class="n">inputs</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span> <span class="c1"># add ground truth segmentation</span> <span class="n">inputs</span><span class="p">[</span><span class="s2">"ground_truth_mask"</span><span class="p">]</span> <span class="o">=</span> <span class="n">ground_truth_mask</span> <span class="k">yield</span> <span class="n">inputs</span> <span class="k">def</span> <span class="nf">get_bounding_box</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">ground_truth_map</span><span class="p">):</span> <span class="c1"># get bounding box from mask</span> <span class="n">y_indices</span><span class="p">,</span> <span class="n">x_indices</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="n">ground_truth_map</span> <span class="o">></span> <span class="mi">0</span><span class="p">)</span> <span class="n">x_min</span><span class="p">,</span> <span class="n">x_max</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">x_indices</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">x_indices</span><span class="p">)</span> <span class="n">y_min</span><span class="p">,</span> <span class="n">y_max</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">min</span><span class="p">(</span><span class="n">y_indices</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">y_indices</span><span class="p">)</span> <span class="c1"># add perturbation to bounding box coordinates</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span> <span class="o">=</span> <span class="n">ground_truth_map</span><span class="o">.</span><span class="n">shape</span> <span class="n">x_min</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">x_min</span> <span class="o">-</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">20</span><span class="p">))</span> <span class="n">x_max</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">W</span><span class="p">,</span> <span class="n">x_max</span> <span class="o">+</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">20</span><span class="p">))</span> <span class="n">y_min</span> <span class="o">=</span> <span class="nb">max</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">y_min</span> <span class="o">-</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">20</span><span class="p">))</span> <span class="n">y_max</span> <span class="o">=</span> <span class="nb">min</span><span class="p">(</span><span class="n">H</span><span class="p">,</span> <span class="n">y_max</span> <span class="o">+</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">20</span><span class="p">))</span> <span class="n">bbox</span> <span class="o">=</span> <span class="p">[</span><span class="n">x_min</span><span class="p">,</span> <span class="n">y_min</span><span class="p">,</span> <span class="n">x_max</span><span class="p">,</span> <span class="n">y_max</span><span class="p">]</span> <span class="k">return</span> <span class="n">bbox</span> </code></pre></div> <p><code>get_bounding_box()</code> is responsible for turning the ground-truth segmentation maps into bounding boxes. These bounding boxes are fed to SAM as prompts (along with the original images) during fine-tuning and SAM is then trained to predict valid masks.</p> <p>The advantage of this first creating a generator and then using it to create a <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><code>tf.data.Dataset</code></a> is the flexbility. Sometimes, we may need to use utitlities from other libraries (<a href="https://albumentations.ai/"><code>albumentations</code></a>, for example) which may not come in native TensorFlow implementations. By using this workflow, we can easily accommodate such use case.</p> <p>But the non-TF counterparts might introduce performance bottlenecks, though. However, for our example, it should work just fine.</p> <p>Now, we prepare the <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><code>tf.data.Dataset</code></a> from our training set.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Define the output signature of the generator class.</span> <span class="n">output_signature</span> <span class="o">=</span> <span class="p">{</span> <span class="s2">"pixel_values"</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span class="n">TensorSpec</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span> <span class="s2">"original_sizes"</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span class="n">TensorSpec</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">int64</span><span class="p">),</span> <span class="s2">"reshaped_input_sizes"</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span class="n">TensorSpec</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">int64</span><span class="p">),</span> <span class="s2">"input_boxes"</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span class="n">TensorSpec</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">float64</span><span class="p">),</span> <span class="s2">"ground_truth_mask"</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span class="n">TensorSpec</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">int32</span><span class="p">),</span> <span class="p">}</span> <span class="c1"># Prepare the dataset object.</span> <span class="n">train_dataset_gen</span> <span class="o">=</span> <span class="n">Generator</span><span class="p">(</span><span class="n">dataset_path</span><span class="p">,</span> <span class="n">processor</span><span class="p">)</span> <span class="n">train_ds</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_generator</span><span class="p">(</span> <span class="n">train_dataset_gen</span><span class="p">,</span> <span class="n">output_signature</span><span class="o">=</span><span class="n">output_signature</span> <span class="p">)</span> </code></pre></div> <p>Next, we configure the dataset for performance.</p> <div class="codehilite"><pre><span></span><code><span class="n">auto</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">2</span> <span class="n">shuffle_buffer</span> <span class="o">=</span> <span class="mi">4</span> <span class="n">train_ds</span> <span class="o">=</span> <span class="p">(</span> <span class="n">train_ds</span><span class="o">.</span><span class="n">cache</span><span class="p">()</span> <span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">shuffle_buffer</span><span class="p">)</span> <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span> <span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">buffer_size</span><span class="o">=</span><span class="n">auto</span><span class="p">)</span> <span class="p">)</span> </code></pre></div> <p>Take a single batch of data and inspect the shapes of the elements present inside of it.</p> <div class="codehilite"><pre><span></span><code><span class="n">sample</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">train_ds</span><span class="p">))</span> <span class="k">for</span> <span class="n">k</span> <span class="ow">in</span> <span class="n">sample</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="n">k</span><span class="p">,</span> <span class="n">sample</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">sample</span><span class="p">[</span><span class="n">k</span><span class="p">]</span><span class="o">.</span><span class="n">dtype</span><span class="p">,</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">sample</span><span class="p">[</span><span class="n">k</span><span class="p">],</span> <span class="n">tf</span><span class="o">.</span><span class="n">Tensor</span><span class="p">))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>pixel_values (2, 3, 1024, 1024) <dtype: 'float32'> True original_sizes (2, 2) <dtype: 'int64'> True reshaped_input_sizes (2, 2) <dtype: 'int64'> True input_boxes (2, 1, 4) <dtype: 'float64'> True ground_truth_mask (2, 256, 256) <dtype: 'int32'> True </code></pre></div> </div> <h3 id="training">Training</h3> <p>We will now write DICE loss. This implementation is based on <a href="https://docs.monai.io/en/stable/losses.html#diceloss">MONAI DICE loss</a>.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">dice_loss</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">,</span> <span class="n">smooth</span><span class="o">=</span><span class="mf">1e-5</span><span class="p">):</span> <span class="n">y_pred</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">y_pred</span><span class="p">)</span> <span class="n">reduce_axis</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">y_pred</span><span class="o">.</span><span class="n">shape</span><span class="p">)))</span> <span class="k">if</span> <span class="n">batch_size</span> <span class="o">></span> <span class="mi">1</span><span class="p">:</span> <span class="c1"># reducing spatial dimensions and batch</span> <span class="n">reduce_axis</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">reduce_axis</span> <span class="n">intersection</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">y_true</span> <span class="o">*</span> <span class="n">y_pred</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="n">reduce_axis</span><span class="p">)</span> <span class="n">y_true_sq</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">math</span><span class="o">.</span><span class="n">pow</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="n">y_pred_sq</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">math</span><span class="o">.</span><span class="n">pow</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span> <span class="n">ground_o</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">y_true_sq</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="n">reduce_axis</span><span class="p">)</span> <span class="n">pred_o</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">y_pred_sq</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="n">reduce_axis</span><span class="p">)</span> <span class="n">denominator</span> <span class="o">=</span> <span class="n">ground_o</span> <span class="o">+</span> <span class="n">pred_o</span> <span class="c1"># calculate DICE coefficient</span> <span class="n">loss</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="p">(</span><span class="mf">2.0</span> <span class="o">*</span> <span class="n">intersection</span> <span class="o">+</span> <span class="mf">1e-5</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">denominator</span> <span class="o">+</span> <span class="mf">1e-5</span><span class="p">)</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_mean</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span> <span class="k">return</span> <span class="n">loss</span> </code></pre></div> <h2 id="finetuning-sam">Fine-tuning SAM</h2> <p>We will now fine-tune SAM's decoder part. We will freeze the vision encoder and prompt encoder layers.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># initialize SAM model and optimizer</span> <span class="n">sam</span> <span class="o">=</span> <span class="n">TFSamModel</span><span class="o">.</span><span class="n">from_pretrained</span><span class="p">(</span><span class="s2">"facebook/sam-vit-base"</span><span class="p">)</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="mf">1e-5</span><span class="p">)</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">sam</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span> <span class="k">if</span> <span class="n">layer</span><span class="o">.</span><span class="n">name</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">"vision_encoder"</span><span class="p">,</span> <span class="s2">"prompt_encoder"</span><span class="p">]:</span> <span class="n">layer</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">False</span> <span class="nd">@tf</span><span class="o">.</span><span class="n">function</span> <span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="n">inputs</span><span class="p">):</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="c1"># pass inputs to SAM model</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">sam</span><span class="p">(</span> <span class="n">pixel_values</span><span class="o">=</span><span class="n">inputs</span><span class="p">[</span><span class="s2">"pixel_values"</span><span class="p">],</span> <span class="n">input_boxes</span><span class="o">=</span><span class="n">inputs</span><span class="p">[</span><span class="s2">"input_boxes"</span><span class="p">],</span> <span class="n">multimask_output</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="n">predicted_masks</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">outputs</span><span class="o">.</span><span class="n">pred_masks</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="n">ground_truth_masks</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">inputs</span><span class="p">[</span><span class="s2">"ground_truth_mask"</span><span class="p">],</span> <span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="c1"># calculate loss over predicted and ground truth masks</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">dice_loss</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">ground_truth_masks</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">predicted_masks</span><span class="p">)</span> <span class="c1"># update trainable variables</span> <span class="n">trainable_vars</span> <span class="o">=</span> <span class="n">sam</span><span class="o">.</span><span class="n">trainable_variables</span> <span class="n">grads</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">)</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">grads</span><span class="p">,</span> <span class="n">trainable_vars</span><span class="p">))</span> <span class="k">return</span> <span class="n">loss</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>All model checkpoint layers were used when initializing TFSamModel. </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>All the layers of TFSamModel were initialized from the model checkpoint at facebook/sam-vit-base. If your task is similar to the task the model of the checkpoint was trained on, you can already use TFSamModel for predictions without further training. WARNING:absl:At this time, the v2.11+ optimizer [`tf.keras.optimizers.Adam`](/api/optimizers/adam#adam-class) runs slowly on M1/M2 Macs, please use the legacy Keras optimizer instead, located at [`tf.keras.optimizers.legacy.Adam`](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/legacy/Adam). </code></pre></div> </div> <p>We can now run the training for three epochs. We might have a warning about gradients not existing on IoU prediction head of mask decoder, we can safely ignore that.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># run training</span> <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">3</span><span class="p">):</span> <span class="k">for</span> <span class="n">inputs</span> <span class="ow">in</span> <span class="n">train_ds</span><span class="p">:</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">train_step</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Epoch </span><span class="si">{</span><span class="n">epoch</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">1</span><span class="si">}</span><span class="s2">: Loss = </span><span class="si">{</span><span class="n">loss</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>WARNING:tensorflow:Gradients do not exist for variables ['tf_sam_model_1/mask_decoder/iou_prediction_head/proj_in/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_in/bias:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_out/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_out/bias:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/layers_._0/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/layers_._0/bias:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument? WARNING:tensorflow:Gradients do not exist for variables ['tf_sam_model_1/mask_decoder/iou_prediction_head/proj_in/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_in/bias:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_out/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_out/bias:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/layers_._0/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/layers_._0/bias:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument? WARNING:tensorflow:Gradients do not exist for variables ['tf_sam_model_1/mask_decoder/iou_prediction_head/proj_in/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_in/bias:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_out/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_out/bias:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/layers_._0/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/layers_._0/bias:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument? WARNING:tensorflow:Gradients do not exist for variables ['tf_sam_model_1/mask_decoder/iou_prediction_head/proj_in/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_in/bias:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_out/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_out/bias:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/layers_._0/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/layers_._0/bias:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument? WARNING:tensorflow:Gradients do not exist for variables ['tf_sam_model_1/mask_decoder/iou_prediction_head/proj_in/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_in/bias:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_out/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_out/bias:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/layers_._0/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/layers_._0/bias:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument? WARNING:tensorflow:Gradients do not exist for variables ['tf_sam_model_1/mask_decoder/iou_prediction_head/proj_in/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_in/bias:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_out/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_out/bias:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/layers_._0/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/layers_._0/bias:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument? WARNING:tensorflow:Gradients do not exist for variables ['tf_sam_model_1/mask_decoder/iou_prediction_head/proj_in/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_in/bias:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_out/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_out/bias:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/layers_._0/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/layers_._0/bias:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument? WARNING:tensorflow:Gradients do not exist for variables ['tf_sam_model_1/mask_decoder/iou_prediction_head/proj_in/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_in/bias:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_out/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/proj_out/bias:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/layers_._0/kernel:0', 'tf_sam_model_1/mask_decoder/iou_prediction_head/layers_._0/bias:0'] when minimizing the loss. If you're using `model.compile()`, did you forget to provide a `loss` argument? Epoch 1: Loss = 0.08322787284851074 Epoch 2: Loss = 0.05677264928817749 Epoch 3: Loss = 0.07764029502868652 </code></pre></div> </div> <h3 id="serialize-the-model">Serialize the model</h3> <p>We serialized the model and pushed for you below. <code>push_to_hub</code> method serializes model, generates a model card and pushes it to Hugging Face Hub, so that other people can load the model using <code>from_pretrained</code> method to infer or further fine-tune. We also need to push the same preprocessor in the repository. Find the model and the preprocessor <a href="https://huggingface.co/merve/sam-finetuned">here</a>.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># sam.push_to_hub("merve/sam-finetuned")</span> <span class="c1"># processor.push_to_hub("merve/sam-finetuned")</span> </code></pre></div> <p>We can now infer with the model.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Load another image for inference.</span> <span class="n">idx</span> <span class="o">=</span> <span class="mi">20</span> <span class="n">raw_image_inference</span> <span class="o">=</span> <span class="n">Image</span><span class="o">.</span><span class="n">open</span><span class="p">(</span><span class="n">image_paths</span><span class="p">[</span><span class="n">idx</span><span class="p">])</span> <span class="c1"># process the image and infer</span> <span class="n">preprocessed_img</span> <span class="o">=</span> <span class="n">processor</span><span class="p">(</span><span class="n">raw_image_inference</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">sam</span><span class="p">(</span><span class="n">preprocessed_img</span><span class="p">)</span> </code></pre></div> <p>Lastly, we can visualize the results.</p> <div class="codehilite"><pre><span></span><code><span class="n">infer_masks</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="s2">"pred_masks"</span><span class="p">]</span> <span class="n">iou_scores</span> <span class="o">=</span> <span class="n">outputs</span><span class="p">[</span><span class="s2">"iou_scores"</span><span class="p">]</span> <span class="n">show_masks_on_image</span><span class="p">(</span><span class="n">raw_image_inference</span><span class="p">,</span> <span class="n">masks</span><span class="o">=</span><span class="n">infer_masks</span><span class="p">,</span> <span class="n">scores</span><span class="o">=</span><span class="n">iou_scores</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). WARNING:matplotlib.image:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers). </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/sam/sam_48_1.png" /></p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#segment-anything-model-with-🤗transformers'>Segment Anything Model with 🤗Transformers</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#installation'>Installation</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#sam-in-a-few-words'>SAM in a few words</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#running-inference-with-sam'>Running inference with SAM</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#finetuning'>Fine-tuning</a> </div> <div class='k-outline-depth-3'> <a href='#data-preparation'>Data preparation</a> </div> <div class='k-outline-depth-3'> <a href='#preparing-tfdatadataset'>Preparing <code>tf.data.Dataset</code></a> </div> <div class='k-outline-depth-3'> <a href='#training'>Training</a> </div> <div class='k-outline-depth-3'> <a href='#serialize-the-model'>Serialize the model</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>