Semantic Segmentation with KerasHub
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/keras_hub/guides/semantic_segmentation_deeplab_v3/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Semantic Segmentation with KerasHub"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Semantic Segmentation with KerasHub"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Semantic Segmentation with KerasHub</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 
'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link active" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-sublink" href="/keras_hub/getting_started/">Getting started</a> <a class="nav-sublink active" href="/keras_hub/guides/">Developer guides</a> <a class="nav-sublink2" href="/keras_hub/guides/upload/">Uploading Models</a> <a class="nav-sublink2" href="/keras_hub/guides/stable_diffusion_3_in_keras_hub/">Stable Diffusion 3</a> <a class="nav-sublink2" href="/keras_hub/guides/segment_anything_in_keras_hub/">Segment Anything</a> <a class="nav-sublink2" href="/keras_hub/guides/classification_with_keras_hub/">Image Classification</a> <a class="nav-sublink2 active" href="/keras_hub/guides/semantic_segmentation_deeplab_v3/">Semantic Segmentation</a> <a 
class="nav-sublink2" href="/keras_hub/guides/transformer_pretraining/">Pretraining a Transformer from scratch</a> <a class="nav-sublink" href="/keras_hub/api/">API documentation</a> <a class="nav-sublink" href="/keras_hub/presets/">Pretrained models list</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." 
autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-content'> <h1 id="semantic-segmentation-with-kerashub">Semantic Segmentation with KerasHub</h1> <p><strong>Authors:</strong> <a href="https://github.com/sachinprasadhs">Sachin Prasad</a>, <a href="https://github.com/divyashreepathihalli">Divyashree Sreepathihalli</a>, <a href="https://github.com/ianstenbit">Ian Stenbit</a><br> <strong>Date created:</strong> 2024/10/11<br> <strong>Last modified:</strong> 2024/10/22<br> <strong>Description:</strong> DeepLabV3 training and inference with KerasHub.</p> 
<p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/keras_hub/semantic_segmentation_deeplab_v3.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/guides/keras_hub/semantic_segmentation_deeplab_v3.py"><strong>GitHub source</strong></a></p> <p><img alt="" src="https://storage.googleapis.com/keras-hub/getting_started_guide/prof_keras_intermediate.png" /></p> <hr /> <h2 id="background">Background</h2> <p>Semantic segmentation is a computer vision task that assigns a class label, such as "person", "bike", or "background", to every pixel of an image. This divides the image into regions that correspond to different object classes.</p> <p><img alt="" src="https://miro.medium.com/v2/resize:fit:4800/format:webp/1*z6ch-2BliDGLIHpOPFY_Sw.png" /></p> <p>KerasHub offers several semantic segmentation models, including DeepLabv3, DeepLabv3+, and SegFormer.</p> <p>This guide demonstrates how to use and fine-tune DeepLabv3+, a semantic segmentation model developed by Google, with KerasHub. Its architecture combines atrous (dilated) convolutions, multi-scale context aggregation, and strong backbone networks to produce accurate, detailed segmentation maps.</p> <p>DeepLabv3+ extends DeepLabv3 by adding a simple yet effective decoder module that refines the segmentation results, especially along object boundaries. 
Both models were competitive with the state of the art, at the time of publication, on segmentation benchmarks such as PASCAL VOC 2012 and Cityscapes.</p> <h3 id="references">References</h3> <ul> <li><a href="https://arxiv.org/abs/1802.02611">Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation</a> (DeepLabv3+)</li> <li><a href="https://arxiv.org/abs/1706.05587">Rethinking Atrous Convolution for Semantic Image Segmentation</a> (DeepLabv3)</li> </ul> <hr /> <h2 id="setup-and-imports">Setup and Imports</h2> <p>Let's install the dependencies and import the necessary modules.</p> <p>To run this tutorial, you will need to install the following packages:</p> <ul> <li><code>keras-hub</code></li> <li><code>keras</code></li> </ul> <div class="codehilite"><pre><span></span><code><span class="err">!</span><span class="n">pip</span> <span class="n">install</span> <span class="o">-</span><span class="n">q</span> <span class="o">--</span><span class="n">upgrade</span> <span class="n">keras</span><span class="o">-</span><span class="n">hub</span>
<span class="err">!</span><span class="n">pip</span> <span class="n">install</span> <span class="o">-</span><span class="n">q</span> <span class="o">--</span><span class="n">upgrade</span> <span class="n">keras</span>
</code></pre></div> <p>After installing <code>keras</code> and <code>keras-hub</code>, set the backend for <code>keras</code>. This guide can be run with any backend (TensorFlow, JAX, or PyTorch). The data-loading code in this guide also uses <code>tensorflow</code> and <code>tensorflow_datasets</code>, so install them as well, whichever backend you choose.</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span><span class="w"> </span><span class="nn">os</span>

<span class="c1"># Set the backend before importing Keras: "jax", "tensorflow", or "torch".</span>
<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"jax"</span>

<span class="kn">import</span><span class="w"> </span><span class="nn">keras</span>
<span class="kn">from</span><span class="w"> </span><span class="nn">keras</span><span class="w"> </span><span class="kn">import</span> <span class="n">ops</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">keras_hub</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">tensorflow</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">tf</span>
<span class="kn">import</span><span class="w"> </span><span class="nn">matplotlib.pyplot</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">plt</span>
</code></pre></div> <hr /> <h2 id="perform-semantic-segmentation-with-a-pretrained-deeplabv3-model">Perform semantic segmentation with a pretrained DeepLabv3+ model</h2> <p>The highest-level API for semantic segmentation in KerasHub is <code>keras_hub.models</code>. It includes fully pretrained semantic segmentation models, such as <a href="/keras_hub/api/models/deeplab_v3/deeplab_v3_image_segmenter#deeplabv3imagesegmenter-class"><code>keras_hub.models.DeepLabV3ImageSegmenter</code></a>.</p> <p>Let's get started by constructing a DeepLabv3+ model pretrained on the Pascal VOC dataset. 
We also define a preprocessor that prepares images (and, during training, labels) for the model. Here it resizes images to 512x512 using bilinear interpolation. <strong>Note:</strong> By default, the <code>from_preset()</code> method in KerasHub loads the pretrained task weights for all of the classes the model was trained on. In this case that is 21 classes: the 20 Pascal VOC object classes plus background.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># DeepLabv3+ with a ResNet50 backbone, pretrained on Pascal VOC.</span>
<span class="n">model</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">DeepLabV3ImageSegmenter</span><span class="o">.</span><span class="n">from_preset</span><span class="p">(</span>
    <span class="s2">"deeplab_v3_plus_resnet50_pascalvoc"</span>
<span class="p">)</span>
<span class="c1"># Resize input images to 512x512 using bilinear interpolation.</span>
<span class="n">image_converter</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">DeepLabV3ImageConverter</span><span class="p">(</span>
    <span class="n">image_size</span><span class="o">=</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mi">512</span><span class="p">),</span>
    <span class="n">interpolation</span><span class="o">=</span><span class="s2">"bilinear"</span><span class="p">,</span>
<span class="p">)</span>
<span class="n">preprocessor</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">DeepLabV3ImageSegmenterPreprocessor</span><span class="p">(</span><span class="n">image_converter</span><span class="p">)</span>
</code></pre></div> <p>Let's visualize the results of this pretrained model:</p> <div class="codehilite"><pre><span></span><code><span class="n">filepath</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">get_file</span><span class="p">(</span>
    <span class="n">origin</span><span class="o">=</span><span class="s2">"https://storage.googleapis.com/keras-cv/models/paligemma/cow_beach_1.png"</span>
<span class="p">)</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">load_img</span><span class="p">(</span><span class="n">filepath</span><span class="p">)</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">image</span><span class="p">)</span>

<span class="c1"># Resize to 512x512 and add a batch dimension: (1, 512, 512, 3).</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">preprocessor</span><span class="p">(</span><span class="n">image</span><span class="p">)</span>
<span class="n">image</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
<span class="c1"># Most likely class per pixel, with a trailing channel axis: (1, 512, 512, 1).</span>
<span class="n">preds</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">image</span><span class="p">),</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">),</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>


<span class="k">def</span><span class="w"> </span><span class="nf">plot_segmentation</span><span class="p">(</span><span class="n">original_image</span><span class="p">,</span> <span class="n">predicted_mask</span><span class="p">):</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">))</span>

    <span class="c1"># Left: the input image, scaled by 1/255 for display.</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">original_image</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="mi">255</span><span class="p">)</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span>

    <span class="c1"># Right: the predicted class-index mask.</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2</span><span class="p">)</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">predicted_mask</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span>

    <span class="n">plt</span><span class="o">.</span><span class="n">tight_layout</span><span class="p">()</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>


<span class="n">plot_segmentation</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">preds</span><span class="p">)</span>
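
# Sketch, not part of the original guide: the `preds` mask above comes from
# taking the argmax over the last (class) axis of the model output, so each
# pixel receives the index of its highest-scoring class. The toy scores below
# (a 2x2 image with 3 hypothetical classes) are made up purely to illustrate
# that step; they are not model outputs.
import numpy as np

toy_scores = np.array(
    [
        [
            [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1]],
            [[0.1, 0.2, 3.0], [0.9, 0.4, 0.2]],
        ]
    ]
)  # shape: (batch=1, height=2, width=2, num_classes=3)
toy_mask = np.argmax(toy_scores, axis=-1)  # shape: (1, 2, 2)
print(toy_mask[0].tolist())  # [[0, 1], [2, 0]]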
</code></pre></div> <p>1/1 ━━━━━━━━━━━━━━━━━━━━ 5s 5s/step</p> <p><img alt="png" src="/img/guides/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_9_3.png" /></p> <hr /> <h2 id="train-a-custom-semantic-segmentation-model">Train a custom semantic segmentation model</h2> <p>In this section, we'll assemble a full training pipeline for a KerasHub DeepLabV3 semantic segmentation model. It covers data loading, augmentation, training, metric evaluation, and inference.</p> <hr /> <h2 id="download-the-data">Download the data</h2> <p>We download the Pascal VOC 2012 dataset together with the additional annotations from <a href="https://ieeexplore.ieee.org/document/6126343">Semantic contours from inverse detectors</a> (the Semantic Boundaries Dataset, SBD). We then split them into a training dataset, <code>train_ds</code>, and an evaluation dataset, <code>eval_ds</code>.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># @title helper functions</span> <span class="kn">import</span><span class="w"> </span><span class="nn">logging</span> <span class="kn">import</span><span class="w"> </span><span class="nn">multiprocessing</span> <span class="kn">from</span><span class="w"> </span><span class="nn">builtins</span><span class="w"> </span><span class="kn">import</span> <span class="nb">open</span> <span class="kn">import</span><span class="w"> </span><span class="nn">os.path</span> <span class="kn">import</span><span class="w"> </span><span class="nn">random</span> <span class="kn">import</span><span class="w"> </span><span class="nn">xml.etree.ElementTree</span> <span class="kn">import</span><span class="w"> </span><span class="nn">tensorflow_datasets</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">tfds</span> <span class="n">VOC_URL</span> <span class="o">=</span> <span
class="s2">"https://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar"</span> <span class="n">SBD_URL</span> <span class="o">=</span> <span class="s2">"https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz"</span> <span class="c1"># Note that this list doesn't contain the background class. In the</span> <span class="c1"># classification use case, labels are 0-based (aeroplane -> 0), whereas in the</span> <span class="c1"># segmentation use case, 0 is reserved for the background, so aeroplane maps to</span> <span class="c1"># 1.</span> <span class="n">CLASSES</span> <span class="o">=</span> <span class="p">[</span> <span class="s2">"aeroplane"</span><span class="p">,</span> <span class="s2">"bicycle"</span><span class="p">,</span> <span class="s2">"bird"</span><span class="p">,</span> <span class="s2">"boat"</span><span class="p">,</span> <span class="s2">"bottle"</span><span class="p">,</span> <span class="s2">"bus"</span><span class="p">,</span> <span class="s2">"car"</span><span class="p">,</span> <span class="s2">"cat"</span><span class="p">,</span> <span class="s2">"chair"</span><span class="p">,</span> <span class="s2">"cow"</span><span class="p">,</span> <span class="s2">"diningtable"</span><span class="p">,</span> <span class="s2">"dog"</span><span class="p">,</span> <span class="s2">"horse"</span><span class="p">,</span> <span class="s2">"motorbike"</span><span class="p">,</span> <span class="s2">"person"</span><span class="p">,</span> <span class="s2">"pottedplant"</span><span class="p">,</span> <span class="s2">"sheep"</span><span class="p">,</span> <span class="s2">"sofa"</span><span class="p">,</span> <span class="s2">"train"</span><span class="p">,</span> <span class="s2">"tvmonitor"</span><span class="p">,</span> <span class="p">]</span> <span class="c1"># This maps each class name string to its index.</span> <span class="n">CLASS_TO_INDEX</span> <span class="o">=</span> <span
class="p">{</span><span class="n">name</span><span class="p">:</span> <span class="n">index</span> <span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">name</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">CLASSES</span><span class="p">)}</span> <span class="c1"># For the mask data in the PNG file, each encoded raw pixel value needs to be</span> <span class="c1"># converted to the proper class index. In the following map, [0, 0, 0] will be</span> <span class="c1"># converted to 0, [128, 0, 0] will be converted to 1, and so on. Also note</span> <span class="c1"># that the mask classes are 1-based, since class 0 is reserved for the</span> <span class="c1"># background. The [128, 0, 0] (class 1) is mapped to `aeroplane`.</span> <span class="n">VOC_PNG_COLOR_VALUE</span> <span class="o">=</span> <span class="p">[</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">128</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">128</span><span class="p">],</span> <span class="p">[</span><span class="mi">128</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">128</span><span class="p">],</span> <span
class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">],</span> <span class="p">[</span><span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">],</span> <span class="p">[</span><span class="mi">64</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">192</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">64</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">192</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">64</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">128</span><span class="p">],</span> <span class="p">[</span><span class="mi">192</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">128</span><span class="p">],</span> <span class="p">[</span><span class="mi">64</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">],</span> <span class="p">[</span><span class="mi">192</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span 
class="mi">128</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">192</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">128</span><span class="p">,</span> <span class="mi">192</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">64</span><span class="p">,</span> <span class="mi">128</span><span class="p">],</span> <span class="p">]</span> <span class="c1"># Will be populated by maybe_populate_voc_color_mapping() below.</span> <span class="n">VOC_PNG_COLOR_MAPPING</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">def</span><span class="w"> </span><span class="nf">maybe_populate_voc_color_mapping</span><span class="p">():</span> <span class="w"> </span><span class="sd">"""Lazy creation of VOC_PNG_COLOR_MAPPING, which could take 64M memory."""</span> <span class="k">global</span> <span class="n">VOC_PNG_COLOR_MAPPING</span> <span class="k">if</span> <span class="n">VOC_PNG_COLOR_MAPPING</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> <span class="n">VOC_PNG_COLOR_MAPPING</span> <span class="o">=</span> <span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="p">(</span><span class="mi">256</span><span class="o">**</span><span class="mi">3</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">colormap</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">VOC_PNG_COLOR_VALUE</span><span class="p">):</span> <span class="n">VOC_PNG_COLOR_MAPPING</span><span 
class="p">[</span> <span class="p">(</span><span class="n">colormap</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="mi">256</span> <span class="o">+</span> <span class="n">colormap</span><span class="p">[</span><span class="mi">1</span><span class="p">])</span> <span class="o">*</span> <span class="mi">256</span> <span class="o">+</span> <span class="n">colormap</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="p">]</span> <span class="o">=</span> <span class="n">i</span> <span class="c1"># There is a special mapping with [224, 224, 192] -> 255</span> <span class="n">VOC_PNG_COLOR_MAPPING</span><span class="p">[</span><span class="mi">224</span> <span class="o">*</span> <span class="mi">256</span> <span class="o">*</span> <span class="mi">256</span> <span class="o">+</span> <span class="mi">224</span> <span class="o">*</span> <span class="mi">256</span> <span class="o">+</span> <span class="mi">192</span><span class="p">]</span> <span class="o">=</span> <span class="mi">255</span> <span class="n">VOC_PNG_COLOR_MAPPING</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="n">VOC_PNG_COLOR_MAPPING</span><span class="p">)</span> <span class="k">return</span> <span class="n">VOC_PNG_COLOR_MAPPING</span> <span class="k">def</span><span class="w"> </span><span class="nf">parse_annotation_data</span><span class="p">(</span><span class="n">annotation_file_path</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Parse the annotation XML file for the image.</span> <span class="sd"> The annotation contains the metadata, as well as the object bounding box</span> <span class="sd"> information.</span> <span class="sd"> """</span> <span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span 
class="n">annotation_file_path</span><span class="p">,</span> <span class="s2">"r"</span><span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span> <span class="n">root</span> <span class="o">=</span> <span class="n">xml</span><span class="o">.</span><span class="n">etree</span><span class="o">.</span><span class="n">ElementTree</span><span class="o">.</span><span class="n">parse</span><span class="p">(</span><span class="n">f</span><span class="p">)</span><span class="o">.</span><span class="n">getroot</span><span class="p">()</span> <span class="n">size</span> <span class="o">=</span> <span class="n">root</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s2">"size"</span><span class="p">)</span> <span class="n">width</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">size</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s2">"width"</span><span class="p">)</span><span class="o">.</span><span class="n">text</span><span class="p">)</span> <span class="n">height</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">size</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s2">"height"</span><span class="p">)</span><span class="o">.</span><span class="n">text</span><span class="p">)</span> <span class="n">objects</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">obj</span> <span class="ow">in</span> <span class="n">root</span><span class="o">.</span><span class="n">findall</span><span class="p">(</span><span class="s2">"object"</span><span class="p">):</span> <span class="c1"># Get object's label name.</span> <span class="n">label</span> <span class="o">=</span> <span class="n">CLASS_TO_INDEX</span><span class="p">[</span><span class="n">obj</span><span 
class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s2">"name"</span><span class="p">)</span><span class="o">.</span><span class="n">text</span><span class="o">.</span><span class="n">lower</span><span class="p">()]</span> <span class="c1"># Get the object's pose name.</span> <span class="n">pose</span> <span class="o">=</span> <span class="n">obj</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s2">"pose"</span><span class="p">)</span><span class="o">.</span><span class="n">text</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span> <span class="n">is_truncated</span> <span class="o">=</span> <span class="n">obj</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s2">"truncated"</span><span class="p">)</span><span class="o">.</span><span class="n">text</span> <span class="o">==</span> <span class="s2">"1"</span> <span class="n">is_difficult</span> <span class="o">=</span> <span class="n">obj</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s2">"difficult"</span><span class="p">)</span><span class="o">.</span><span class="n">text</span> <span class="o">==</span> <span class="s2">"1"</span> <span class="n">bndbox</span> <span class="o">=</span> <span class="n">obj</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s2">"bndbox"</span><span class="p">)</span> <span class="n">xmax</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">bndbox</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s2">"xmax"</span><span class="p">)</span><span class="o">.</span><span class="n">text</span><span class="p">)</span> <span class="n">xmin</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">bndbox</span><span 
class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s2">"xmin"</span><span class="p">)</span><span class="o">.</span><span class="n">text</span><span class="p">)</span> <span class="n">ymax</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">bndbox</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s2">"ymax"</span><span class="p">)</span><span class="o">.</span><span class="n">text</span><span class="p">)</span> <span class="n">ymin</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">bndbox</span><span class="o">.</span><span class="n">find</span><span class="p">(</span><span class="s2">"ymin"</span><span class="p">)</span><span class="o">.</span><span class="n">text</span><span class="p">)</span> <span class="n">objects</span><span class="o">.</span><span class="n">append</span><span class="p">(</span> <span class="p">{</span> <span class="s2">"label"</span><span class="p">:</span> <span class="n">label</span><span class="p">,</span> <span class="s2">"pose"</span><span class="p">:</span> <span class="n">pose</span><span class="p">,</span> <span class="s2">"bbox"</span><span class="p">:</span> <span class="p">[</span><span class="n">ymin</span><span class="p">,</span> <span class="n">xmin</span><span class="p">,</span> <span class="n">ymax</span><span class="p">,</span> <span class="n">xmax</span><span class="p">],</span> <span class="s2">"is_truncated"</span><span class="p">:</span> <span class="n">is_truncated</span><span class="p">,</span> <span class="s2">"is_difficult"</span><span class="p">:</span> <span class="n">is_difficult</span><span class="p">,</span> <span class="p">}</span> <span class="p">)</span> <span class="k">return</span> <span class="p">{</span><span class="s2">"width"</span><span class="p">:</span> <span class="n">width</span><span class="p">,</span> <span 
class="s2">"height"</span><span class="p">:</span> <span class="n">height</span><span class="p">,</span> <span class="s2">"objects"</span><span class="p">:</span> <span class="n">objects</span><span class="p">}</span> <span class="k">def</span><span class="w"> </span><span class="nf">get_image_ids</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="n">split</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Get image IDs from the "train", "eval", or "trainval" split file of the VOC data."""</span> <span class="n">data_file_mapping</span> <span class="o">=</span> <span class="p">{</span> <span class="s2">"train"</span><span class="p">:</span> <span class="s2">"train.txt"</span><span class="p">,</span> <span class="s2">"eval"</span><span class="p">:</span> <span class="s2">"val.txt"</span><span class="p">,</span> <span class="s2">"trainval"</span><span class="p">:</span> <span class="s2">"trainval.txt"</span><span class="p">,</span> <span class="p">}</span> <span class="k">with</span> <span class="nb">open</span><span class="p">(</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="s2">"ImageSets"</span><span class="p">,</span> <span class="s2">"Segmentation"</span><span class="p">,</span> <span class="n">data_file_mapping</span><span class="p">[</span><span class="n">split</span><span class="p">]),</span> <span class="s2">"r"</span><span class="p">,</span> <span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span> <span class="n">image_ids</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">()</span><span class="o">.</span><span class="n">splitlines</span><span class="p">()</span> <span class="n">logging</span><span 
class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Received </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">image_ids</span><span class="p">)</span><span class="si">}</span><span class="s2"> images for </span><span class="si">{</span><span class="n">split</span><span class="si">}</span><span class="s2"> dataset."</span><span class="p">)</span> <span class="k">return</span> <span class="n">image_ids</span> <span class="k">def</span><span class="w"> </span><span class="nf">get_sbd_image_ids</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="n">split</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Get image IDs from the "sbd_train" or "sbd_eval" split file of the SBD data."""</span> <span class="n">data_file_mapping</span> <span class="o">=</span> <span class="p">{</span><span class="s2">"sbd_train"</span><span class="p">:</span> <span class="s2">"train.txt"</span><span class="p">,</span> <span class="s2">"sbd_eval"</span><span class="p">:</span> <span class="s2">"val.txt"</span><span class="p">}</span> <span class="k">with</span> <span class="nb">open</span><span class="p">(</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="n">data_file_mapping</span><span class="p">[</span><span class="n">split</span><span class="p">]),</span> <span class="s2">"r"</span><span class="p">,</span> <span class="p">)</span> <span class="k">as</span> <span class="n">f</span><span class="p">:</span> <span class="n">image_ids</span> <span class="o">=</span> <span class="n">f</span><span class="o">.</span><span class="n">read</span><span class="p">()</span><span class="o">.</span><span class="n">splitlines</span><span 
class="p">()</span> <span class="n">logging</span><span class="o">.</span><span class="n">info</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Received </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">image_ids</span><span class="p">)</span><span class="si">}</span><span class="s2"> images for </span><span class="si">{</span><span class="n">split</span><span class="si">}</span><span class="s2"> dataset."</span><span class="p">)</span> <span class="k">return</span> <span class="n">image_ids</span> <span class="k">def</span><span class="w"> </span><span class="nf">parse_single_image</span><span class="p">(</span><span class="n">image_file_path</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Create the metadata (file paths and annotations) for a single VOC image."""</span> <span class="n">data_dir</span><span class="p">,</span> <span class="n">image_file_name</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">image_file_path</span><span class="p">)</span> <span class="n">data_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">normpath</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">pardir</span><span class="p">))</span> <span class="n">image_id</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span 
class="n">splitext</span><span class="p">(</span><span class="n">image_file_name</span><span class="p">)</span> <span class="n">class_segmentation_file_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span> <span class="n">data_dir</span><span class="p">,</span> <span class="s2">"SegmentationClass"</span><span class="p">,</span> <span class="n">image_id</span> <span class="o">+</span> <span class="s2">".png"</span> <span class="p">)</span> <span class="n">object_segmentation_file_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span> <span class="n">data_dir</span><span class="p">,</span> <span class="s2">"SegmentationObject"</span><span class="p">,</span> <span class="n">image_id</span> <span class="o">+</span> <span class="s2">".png"</span> <span class="p">)</span> <span class="n">annotation_file_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="s2">"Annotations"</span><span class="p">,</span> <span class="n">image_id</span> <span class="o">+</span> <span class="s2">".xml"</span><span class="p">)</span> <span class="n">image_annotations</span> <span class="o">=</span> <span class="n">parse_annotation_data</span><span class="p">(</span><span class="n">annotation_file_path</span><span class="p">)</span> <span class="n">result</span> <span class="o">=</span> <span class="p">{</span> <span class="s2">"image/filename"</span><span class="p">:</span> <span class="n">image_id</span> <span class="o">+</span> <span class="s2">".jpg"</span><span class="p">,</span> <span 
class="s2">"image/file_path"</span><span class="p">:</span> <span class="n">image_file_path</span><span class="p">,</span> <span class="s2">"segmentation/class/file_path"</span><span class="p">:</span> <span class="n">class_segmentation_file_path</span><span class="p">,</span> <span class="s2">"segmentation/object/file_path"</span><span class="p">:</span> <span class="n">object_segmentation_file_path</span><span class="p">,</span> <span class="p">}</span> <span class="n">result</span><span class="o">.</span><span class="n">update</span><span class="p">(</span><span class="n">image_annotations</span><span class="p">)</span> <span class="c1"># The "labels" field is the sorted set of unique 'object.label' values.</span> <span class="n">labels</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">set</span><span class="p">([</span><span class="n">o</span><span class="p">[</span><span class="s2">"label"</span><span class="p">]</span> <span class="k">for</span> <span class="n">o</span> <span class="ow">in</span> <span class="n">result</span><span class="p">[</span><span class="s2">"objects"</span><span class="p">]]))</span> <span class="n">result</span><span class="p">[</span><span class="s2">"labels"</span><span class="p">]</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span> <span class="k">return</span> <span class="n">result</span> <span class="k">def</span><span class="w"> </span><span class="nf">parse_single_sbd_image</span><span class="p">(</span><span class="n">image_file_path</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Create the metadata (file paths) for a single SBD image."""</span> <span class="n">data_dir</span><span class="p">,</span> <span class="n">image_file_name</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">split</span><span 
class="p">(</span><span class="n">image_file_path</span><span class="p">)</span> <span class="n">data_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">normpath</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">pardir</span><span class="p">))</span> <span class="n">image_id</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">splitext</span><span class="p">(</span><span class="n">image_file_name</span><span class="p">)</span> <span class="n">class_segmentation_file_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="s2">"cls"</span><span class="p">,</span> <span class="n">image_id</span> <span class="o">+</span> <span class="s2">".mat"</span><span class="p">)</span> <span class="n">object_segmentation_file_path</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="s2">"inst"</span><span class="p">,</span> <span class="n">image_id</span> <span class="o">+</span> <span class="s2">".mat"</span><span class="p">)</span> <span class="n">result</span> <span class="o">=</span> <span class="p">{</span> <span 
class="s2">"image/filename"</span><span class="p">:</span> <span class="n">image_id</span> <span class="o">+</span> <span class="s2">".jpg"</span><span class="p">,</span> <span class="s2">"image/file_path"</span><span class="p">:</span> <span class="n">image_file_path</span><span class="p">,</span> <span class="s2">"segmentation/class/file_path"</span><span class="p">:</span> <span class="n">class_segmentation_file_path</span><span class="p">,</span> <span class="s2">"segmentation/object/file_path"</span><span class="p">:</span> <span class="n">object_segmentation_file_path</span><span class="p">,</span> <span class="p">}</span> <span class="k">return</span> <span class="n">result</span> <span class="k">def</span><span class="w"> </span><span class="nf">build_metadata</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="n">image_ids</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Transpose the metadata from a list of dicts into a dict of lists."""</span> <span class="c1"># Process all the images in parallel.</span> <span class="n">image_file_paths</span> <span class="o">=</span> <span class="p">[</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="s2">"JPEGImages"</span><span class="p">,</span> <span class="n">i</span> <span class="o">+</span> <span class="s2">".jpg"</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">image_ids</span> <span class="p">]</span> <span class="n">pool_size</span> <span class="o">=</span> <span class="mi">10</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">image_ids</span><span class="p">)</span> <span class="o">></span> <span class="mi">10</span> <span 
class="k">else</span> <span class="nb">len</span><span class="p">(</span><span class="n">image_ids</span><span class="p">)</span> <span class="k">with</span> <span class="n">multiprocessing</span><span class="o">.</span><span class="n">Pool</span><span class="p">(</span><span class="n">pool_size</span><span class="p">)</span> <span class="k">as</span> <span class="n">p</span><span class="p">:</span> <span class="n">metadata</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">parse_single_image</span><span class="p">,</span> <span class="n">image_file_paths</span><span class="p">)</span> <span class="n">keys</span> <span class="o">=</span> <span class="p">[</span> <span class="s2">"image/filename"</span><span class="p">,</span> <span class="s2">"image/file_path"</span><span class="p">,</span> <span class="s2">"segmentation/class/file_path"</span><span class="p">,</span> <span class="s2">"segmentation/object/file_path"</span><span class="p">,</span> <span class="s2">"labels"</span><span class="p">,</span> <span class="s2">"width"</span><span class="p">,</span> <span class="s2">"height"</span><span class="p">,</span> <span class="p">]</span> <span class="n">result</span> <span class="o">=</span> <span class="p">{}</span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">keys</span><span class="p">:</span> <span class="n">values</span> <span class="o">=</span> <span class="p">[</span><span class="n">value</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="k">for</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">metadata</span><span class="p">]</span> <span class="n">result</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">values</span> <span class="c1"># The ragged objects need some 
special handling</span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="p">[</span><span class="s2">"label"</span><span class="p">,</span> <span class="s2">"pose"</span><span class="p">,</span> <span class="s2">"bbox"</span><span class="p">,</span> <span class="s2">"is_truncated"</span><span class="p">,</span> <span class="s2">"is_difficult"</span><span class="p">]:</span> <span class="n">values</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">objects</span> <span class="o">=</span> <span class="p">[</span><span class="n">value</span><span class="p">[</span><span class="s2">"objects"</span><span class="p">]</span> <span class="k">for</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">metadata</span><span class="p">]</span> <span class="k">for</span> <span class="n">image_objects</span> <span class="ow">in</span> <span class="n">objects</span><span class="p">:</span> <span class="n">values</span><span class="o">.</span><span class="n">append</span><span class="p">([</span><span class="n">o</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="k">for</span> <span class="n">o</span> <span class="ow">in</span> <span class="n">image_objects</span><span class="p">])</span> <span class="n">result</span><span class="p">[</span><span class="s2">"objects/"</span> <span class="o">+</span> <span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">values</span> <span class="k">return</span> <span class="n">result</span> <span class="k">def</span><span class="w"> </span><span class="nf">build_sbd_metadata</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="n">image_ids</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Transpose the metadata from a list of dicts into a dict of lists."""</span> <span class="c1"># Parallel process all the 
images.</span> <span class="n">image_file_paths</span> <span class="o">=</span> <span class="p">[</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="s2">"img"</span><span class="p">,</span> <span class="n">i</span> <span class="o">+</span> <span class="s2">".jpg"</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">image_ids</span><span class="p">]</span> <span class="n">pool_size</span> <span class="o">=</span> <span class="mi">10</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="n">image_ids</span><span class="p">)</span> <span class="o">></span> <span class="mi">10</span> <span class="k">else</span> <span class="nb">len</span><span class="p">(</span><span class="n">image_ids</span><span class="p">)</span> <span class="k">with</span> <span class="n">multiprocessing</span><span class="o">.</span><span class="n">Pool</span><span class="p">(</span><span class="n">pool_size</span><span class="p">)</span> <span class="k">as</span> <span class="n">p</span><span class="p">:</span> <span class="n">metadata</span> <span class="o">=</span> <span class="n">p</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">parse_single_sbd_image</span><span class="p">,</span> <span class="n">image_file_paths</span><span class="p">)</span> <span class="n">keys</span> <span class="o">=</span> <span class="p">[</span> <span class="s2">"image/filename"</span><span class="p">,</span> <span class="s2">"image/file_path"</span><span class="p">,</span> <span class="s2">"segmentation/class/file_path"</span><span class="p">,</span> <span class="s2">"segmentation/object/file_path"</span><span class="p">,</span> <span class="p">]</span> <span class="n">result</span> 
<span class="o">=</span> <span class="p">{}</span> <span class="k">for</span> <span class="n">key</span> <span class="ow">in</span> <span class="n">keys</span><span class="p">:</span> <span class="n">values</span> <span class="o">=</span> <span class="p">[</span><span class="n">value</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="k">for</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">metadata</span><span class="p">]</span> <span class="n">result</span><span class="p">[</span><span class="n">key</span><span class="p">]</span> <span class="o">=</span> <span class="n">values</span> <span class="k">return</span> <span class="n">result</span> <span class="k">def</span><span class="w"> </span><span class="nf">decode_png_mask</span><span class="p">(</span><span class="n">mask</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Decode the raw PNG mask by mapping each RGB color to its class index,</span> <span class="sd"> returning a uint8 mask of shape (height, width, 1)."""</span> <span class="c1"># Cast the mask to int32 since the original uint8 would overflow when</span> <span class="c1"># multiplied by 256.</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">tf</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">mask</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="mi">256</span> <span class="o">*</span> <span class="mi">256</span> <span class="o">+</span> <span class="n">mask</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="mi">256</span> <span 
class="o">+</span> <span class="n">mask</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">2</span><span class="p">]</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">gather</span><span class="p">(</span><span class="n">VOC_PNG_COLOR_MAPPING</span><span class="p">,</span> <span class="n">mask</span><span class="p">),</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="n">mask</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="n">tf</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span> <span class="k">return</span> <span class="n">mask</span> <span class="k">def</span><span class="w"> </span><span class="nf">load_images</span><span class="p">(</span><span class="n">example</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Load a VOC image and its segmentation masks from the provided paths."""</span> <span class="n">image_file_path</span> <span class="o">=</span> <span class="n">example</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">"image/file_path"</span><span class="p">)</span> <span class="n">segmentation_class_file_path</span> <span class="o">=</span> <span class="n">example</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">"segmentation/class/file_path"</span><span class="p">)</span> <span class="n">segmentation_object_file_path</span> <span class="o">=</span> <span class="n">example</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">"segmentation/object/file_path"</span><span 
class="p">)</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">read_file</span><span class="p">(</span><span class="n">image_file_path</span><span class="p">)</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">decode_jpeg</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="n">segmentation_class_mask</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">read_file</span><span class="p">(</span><span class="n">segmentation_class_file_path</span><span class="p">)</span> <span class="n">segmentation_class_mask</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">decode_png</span><span class="p">(</span><span class="n">segmentation_class_mask</span><span class="p">)</span> <span class="n">segmentation_class_mask</span> <span class="o">=</span> <span class="n">decode_png_mask</span><span class="p">(</span><span class="n">segmentation_class_mask</span><span class="p">)</span> <span class="n">segmentation_object_mask</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">read_file</span><span class="p">(</span><span class="n">segmentation_object_file_path</span><span class="p">)</span> <span class="n">segmentation_object_mask</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">decode_png</span><span class="p">(</span><span class="n">segmentation_object_mask</span><span class="p">)</span> <span 
class="n">segmentation_object_mask</span> <span class="o">=</span> <span class="n">decode_png_mask</span><span class="p">(</span><span class="n">segmentation_object_mask</span><span class="p">)</span> <span class="n">example</span><span class="o">.</span><span class="n">update</span><span class="p">(</span> <span class="p">{</span> <span class="s2">"image"</span><span class="p">:</span> <span class="n">image</span><span class="p">,</span> <span class="s2">"class_segmentation"</span><span class="p">:</span> <span class="n">segmentation_class_mask</span><span class="p">,</span> <span class="s2">"object_segmentation"</span><span class="p">:</span> <span class="n">segmentation_object_mask</span><span class="p">,</span> <span class="p">}</span> <span class="p">)</span> <span class="k">return</span> <span class="n">example</span> <span class="k">def</span><span class="w"> </span><span class="nf">load_sbd_images</span><span class="p">(</span><span class="n">image_file_path</span><span class="p">,</span> <span class="n">seg_cls_file_path</span><span class="p">,</span> <span class="n">seg_obj_file_path</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Load an SBD image and its segmentation masks from the provided paths."""</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">read_file</span><span class="p">(</span><span class="n">image_file_path</span><span class="p">)</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">decode_jpeg</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="n">segmentation_class_mask</span> <span class="o">=</span> <span class="n">tfds</span><span class="o">.</span><span class="n">core</span><span class="o">.</span><span 
class="n">lazy_imports</span><span class="o">.</span><span class="n">scipy</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">loadmat</span><span class="p">(</span><span class="n">seg_cls_file_path</span><span class="p">)</span> <span class="n">segmentation_class_mask</span> <span class="o">=</span> <span class="n">segmentation_class_mask</span><span class="p">[</span><span class="s2">"GTcls"</span><span class="p">][</span><span class="s2">"Segmentation"</span><span class="p">][</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="n">segmentation_class_mask</span> <span class="o">=</span> <span class="n">segmentation_class_mask</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">newaxis</span><span class="p">]</span> <span class="n">segmentation_object_mask</span> <span class="o">=</span> <span class="n">tfds</span><span class="o">.</span><span class="n">core</span><span class="o">.</span><span class="n">lazy_imports</span><span class="o">.</span><span class="n">scipy</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">loadmat</span><span class="p">(</span> <span class="n">seg_obj_file_path</span> <span class="p">)</span> <span class="n">segmentation_object_mask</span> <span class="o">=</span> <span class="n">segmentation_object_mask</span><span class="p">[</span><span class="s2">"GTinst"</span><span class="p">][</span><span class="s2">"Segmentation"</span><span class="p">][</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="n">segmentation_object_mask</span> <span class="o">=</span> <span class="n">segmentation_object_mask</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="n">np</span><span 
class="o">.</span><span class="n">newaxis</span><span class="p">]</span> <span class="k">return</span> <span class="p">{</span> <span class="s2">"image"</span><span class="p">:</span> <span class="n">image</span><span class="p">,</span> <span class="s2">"class_segmentation"</span><span class="p">:</span> <span class="n">segmentation_class_mask</span><span class="p">,</span> <span class="s2">"object_segmentation"</span><span class="p">:</span> <span class="n">segmentation_object_mask</span><span class="p">,</span> <span class="p">}</span> <span class="k">def</span><span class="w"> </span><span class="nf">build_dataset_from_metadata</span><span class="p">(</span><span class="n">metadata</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Builds TensorFlow dataset from the image metadata of VOC dataset."""</span> <span class="c1"># The objects need some manual conversion to ragged tensor.</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">"labels"</span><span class="p">]</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">ragged</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="n">metadata</span><span class="p">[</span><span class="s2">"labels"</span><span class="p">])</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">"objects/label"</span><span class="p">]</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">ragged</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="n">metadata</span><span class="p">[</span><span class="s2">"objects/label"</span><span class="p">])</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">"objects/pose"</span><span class="p">]</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">ragged</span><span 
class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="n">metadata</span><span class="p">[</span><span class="s2">"objects/pose"</span><span class="p">])</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">"objects/is_truncated"</span><span class="p">]</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">ragged</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">"objects/is_truncated"</span><span class="p">]</span> <span class="p">)</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">"objects/is_difficult"</span><span class="p">]</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">ragged</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">"objects/is_difficult"</span><span class="p">]</span> <span class="p">)</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">"objects/bbox"</span><span class="p">]</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">ragged</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">"objects/bbox"</span><span class="p">],</span> <span class="n">ragged_rank</span><span class="o">=</span><span class="mi">1</span> <span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">(</span><span class="n">metadata</span><span class="p">)</span> <span 
class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">load_images</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span><span class="p">)</span> <span class="k">return</span> <span class="n">dataset</span> <span class="k">def</span><span class="w"> </span><span class="nf">build_sbd_dataset_from_metadata</span><span class="p">(</span><span class="n">metadata</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Builds TensorFlow dataset from the image metadata of SBD dataset."""</span> <span class="n">img_filepath</span> <span class="o">=</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">"image/file_path"</span><span class="p">]</span> <span class="n">cls_filepath</span> <span class="o">=</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">"segmentation/class/file_path"</span><span class="p">]</span> <span class="n">obj_filepath</span> <span class="o">=</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">"segmentation/object/file_path"</span><span class="p">]</span> <span class="k">def</span><span class="w"> </span><span class="nf">md_gen</span><span class="p">():</span> <span class="n">c</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">img_filepath</span><span class="p">,</span> <span class="n">cls_filepath</span><span class="p">,</span> <span class="n">obj_filepath</span><span class="p">))</span> <span class="c1"># Reshuffle the file list every time the generator is called, so each pass sees a new order.</span> <span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span 
class="p">(</span><span class="n">c</span><span class="p">)</span> <span class="k">for</span> <span class="n">fp</span> <span class="ow">in</span> <span class="n">c</span><span class="p">:</span> <span class="n">img_fp</span><span class="p">,</span> <span class="n">cls_fp</span><span class="p">,</span> <span class="n">obj_fp</span> <span class="o">=</span> <span class="n">fp</span> <span class="k">yield</span> <span class="n">load_sbd_images</span><span class="p">(</span><span class="n">img_fp</span><span class="p">,</span> <span class="n">cls_fp</span><span class="p">,</span> <span class="n">obj_fp</span><span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_generator</span><span class="p">(</span> <span class="n">md_gen</span><span class="p">,</span> <span class="n">output_signature</span><span class="o">=</span><span class="p">(</span> <span class="p">{</span> <span class="s2">"image"</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span class="n">TensorSpec</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">uint8</span><span class="p">),</span> <span class="s2">"class_segmentation"</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span class="n">TensorSpec</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span 
class="mi">1</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">uint8</span> <span class="p">),</span> <span class="s2">"object_segmentation"</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span class="n">TensorSpec</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">uint8</span> <span class="p">),</span> <span class="p">}</span> <span class="p">),</span> <span class="p">)</span> <span class="k">return</span> <span class="n">dataset</span> <span class="k">def</span><span class="w"> </span><span class="nf">load</span><span class="p">(</span> <span class="n">split</span><span class="o">=</span><span class="s2">"sbd_train"</span><span class="p">,</span> <span class="n">data_dir</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="p">):</span> <span class="w"> </span><span class="sd">"""Load the Pascal VOC 2012 dataset.</span> <span class="sd"> This function will download the data tar file from remote if needed,</span> <span class="sd"> untar it to the local `data_dir`, and build a dataset from it.</span> <span class="sd"> It supports both VOC2012 and the Semantic Boundaries Dataset (SBD).</span> <span class="sd"> The returned segmentation masks contain ints in the range [0, num_classes),</span> <span class="sd"> plus the value 255, which marks boundary pixels.</span> <span class="sd"> Args:</span> <span class="sd"> split: string, can be 'train', 'eval', 'trainval', 'sbd_train', or</span> <span class="sd"> 'sbd_eval'. 
'sbd_train' represents the training dataset for SBD</span> <span class="sd"> dataset, while 'train' represents the training dataset for VOC2012</span> <span class="sd"> dataset. Defaults to `sbd_train`.</span> <span class="sd"> data_dir: string, local directory path for the loaded data. The data</span> <span class="sd"> file is downloaded and extracted here, and the directory also acts as</span> <span class="sd"> a cache. Defaults to None, in which case the default cache directory</span> <span class="sd"> of `keras.utils.get_file` (typically `~/.keras/datasets`) is used.</span> <span class="sd"> """</span> <span class="n">supported_split_value</span> <span class="o">=</span> <span class="p">[</span> <span class="s2">"train"</span><span class="p">,</span> <span class="s2">"eval"</span><span class="p">,</span> <span class="s2">"trainval"</span><span class="p">,</span> <span class="s2">"sbd_train"</span><span class="p">,</span> <span class="s2">"sbd_eval"</span><span class="p">,</span> <span class="p">]</span> <span class="k">if</span> <span class="n">split</span> <span class="ow">not</span> <span class="ow">in</span> <span class="n">supported_split_value</span><span class="p">:</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span> <span class="sa">f</span><span class="s2">"The supported values for `split` are </span><span class="si">{</span><span class="n">supported_split_value</span><span class="si">}</span><span class="s2">. 
"</span> <span class="sa">f</span><span class="s2">"Got: </span><span class="si">{</span><span class="n">split</span><span class="si">}</span><span class="s2">"</span> <span class="p">)</span> <span class="k">if</span> <span class="n">data_dir</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="n">data_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">expanduser</span><span class="p">(</span><span class="n">data_dir</span><span class="p">)</span> <span class="k">if</span> <span class="s2">"sbd"</span> <span class="ow">in</span> <span class="n">split</span><span class="p">:</span> <span class="k">return</span> <span class="n">load_sbd</span><span class="p">(</span><span class="n">split</span><span class="p">,</span> <span class="n">data_dir</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="k">return</span> <span class="n">load_voc</span><span class="p">(</span><span class="n">split</span><span class="p">,</span> <span class="n">data_dir</span><span class="p">)</span> <span class="k">def</span><span class="w"> </span><span class="nf">load_voc</span><span class="p">(</span> <span class="n">split</span><span class="o">=</span><span class="s2">"train"</span><span class="p">,</span> <span class="n">data_dir</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="p">):</span> <span class="w"> </span><span class="sd">"""This function will download VOC data from a URL. 
If the data is already</span> <span class="sd"> present in the cache directory, it will load the data from that directory</span> <span class="sd"> instead.</span> <span class="sd"> """</span> <span class="n">extracted_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="s2">"VOCdevkit"</span><span class="p">,</span> <span class="s2">"VOC2012"</span><span class="p">)</span> <span class="n">get_data</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">get_file</span><span class="p">(</span> <span class="n">fname</span><span class="o">=</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">basename</span><span class="p">(</span><span class="n">VOC_URL</span><span class="p">),</span> <span class="n">origin</span><span class="o">=</span><span class="n">VOC_URL</span><span class="p">,</span> <span class="n">cache_dir</span><span class="o">=</span><span class="n">data_dir</span><span class="p">,</span> <span class="n">extract</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="n">data_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="n">get_data</span><span class="p">),</span> <span class="n">extracted_dir</span><span class="p">)</span> <span class="n">image_ids</span> <span class="o">=</span> <span class="n">get_image_ids</span><span class="p">(</span><span class="n">data_dir</span><span 
class="p">,</span> <span class="n">split</span><span class="p">)</span> <span class="c1"># len(metadata) = #samples, metadata[i] is a dict.</span> <span class="n">metadata</span> <span class="o">=</span> <span class="n">build_metadata</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="n">image_ids</span><span class="p">)</span> <span class="n">maybe_populate_voc_color_mapping</span><span class="p">()</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">build_dataset_from_metadata</span><span class="p">(</span><span class="n">metadata</span><span class="p">)</span> <span class="k">return</span> <span class="n">dataset</span> <span class="k">def</span><span class="w"> </span><span class="nf">load_sbd</span><span class="p">(</span> <span class="n">split</span><span class="o">=</span><span class="s2">"sbd_train"</span><span class="p">,</span> <span class="n">data_dir</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="p">):</span> <span class="w"> </span><span class="sd">"""This function will download SBD data from a URL. 
If the data is already</span> <span class="sd"> present in the cache directory, it will load the data from that directory</span> <span class="sd"> instead.</span> <span class="sd"> """</span> <span class="n">extracted_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="s2">"benchmark_RELEASE"</span><span class="p">,</span> <span class="s2">"dataset"</span><span class="p">)</span> <span class="n">get_data</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">get_file</span><span class="p">(</span> <span class="n">fname</span><span class="o">=</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">basename</span><span class="p">(</span><span class="n">SBD_URL</span><span class="p">),</span> <span class="n">origin</span><span class="o">=</span><span class="n">SBD_URL</span><span class="p">,</span> <span class="n">cache_dir</span><span class="o">=</span><span class="n">data_dir</span><span class="p">,</span> <span class="n">extract</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="n">data_dir</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">dirname</span><span class="p">(</span><span class="n">get_data</span><span class="p">),</span> <span class="n">extracted_dir</span><span class="p">)</span> <span class="n">image_ids</span> <span class="o">=</span> <span class="n">get_sbd_image_ids</span><span class="p">(</span><span 
class="n">data_dir</span><span class="p">,</span> <span class="n">split</span><span class="p">)</span> <span class="c1"># len(metadata) = #samples, metadata[i] is a dict.</span> <span class="n">metadata</span> <span class="o">=</span> <span class="n">build_sbd_metadata</span><span class="p">(</span><span class="n">data_dir</span><span class="p">,</span> <span class="n">image_ids</span><span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">build_sbd_dataset_from_metadata</span><span class="p">(</span><span class="n">metadata</span><span class="p">)</span> <span class="k">return</span> <span class="n">dataset</span> </code></pre></div> <hr /> <h2 id="load-the-dataset">Load the dataset</h2> <p>For training and evaluation, let's use "sbd_train" and "sbd_eval." You can also choose any of these datasets for the <code>load</code> function: 'train', 'eval', 'trainval', 'sbd_train', or 'sbd_eval'. 'sbd_train' represents the training dataset for the SBD dataset, while 'train' represents the training dataset for the VOC2012 dataset.</p> <div class="codehilite"><pre><span></span><code><span class="n">train_ds</span> <span class="o">=</span> <span class="n">load</span><span class="p">(</span><span class="n">split</span><span class="o">=</span><span class="s2">"sbd_train"</span><span class="p">,</span> <span class="n">data_dir</span><span class="o">=</span><span class="s2">"segmentation"</span><span class="p">)</span> <span class="n">eval_ds</span> <span class="o">=</span> <span class="n">load</span><span class="p">(</span><span class="n">split</span><span class="o">=</span><span class="s2">"sbd_eval"</span><span class="p">,</span> <span class="n">data_dir</span><span class="o">=</span><span class="s2">"segmentation"</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="preprocess-the-data">Preprocess the data</h2> <p>The preprocess_inputs utility function preprocesses inputs, converting them into a dictionary containing 
images and segmentation_masks. Both images and segmentation masks are resized to 512x512. The resulting dataset is then batched into groups of four image and segmentation mask pairs.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">preprocess_inputs</span><span class="p">(</span><span class="n">inputs</span><span class="p">):</span> <span class="k">def</span><span class="w"> </span><span class="nf">unpackage_inputs</span><span class="p">(</span><span class="n">inputs</span><span class="p">):</span> <span class="k">return</span> <span class="p">{</span> <span class="s2">"images"</span><span class="p">:</span> <span class="n">inputs</span><span class="p">[</span><span class="s2">"image"</span><span class="p">],</span> <span class="s2">"segmentation_masks"</span><span class="p">:</span> <span class="n">inputs</span><span class="p">[</span><span class="s2">"class_segmentation"</span><span class="p">],</span> <span class="p">}</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">inputs</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">unpackage_inputs</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">outputs</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Resizing</span><span class="p">(</span><span class="n">height</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">width</span><span class="o">=</span><span class="mi">512</span><span class="p">))</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">outputs</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span 
class="n">drop_remainder</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="k">return</span> <span class="n">outputs</span> <span class="n">train_ds</span> <span class="o">=</span> <span class="n">preprocess_inputs</span><span class="p">(</span><span class="n">train_ds</span><span class="p">)</span> <span class="n">batch</span> <span class="o">=</span> <span class="n">train_ds</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">get_single_element</span><span class="p">()</span> </code></pre></div> <p>A batch of the preprocessed training data can be visualized with the <code>plot_images_masks</code> function. It takes a batch of images, their ground-truth segmentation masks, and optionally a batch of predicted masks, and displays them in a grid with one column per image.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">plot_images_masks</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">masks</span><span class="p">,</span> <span class="n">pred_masks</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">num_images</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">images</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span> <span class="n">rows</span> <span class="o">=</span> <span class="mi">3</span> <span class="k">if</span> <span class="n">pred_masks</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span> <span class="k">else</span> <span 
class="mi">2</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_images</span><span class="p">):</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="n">rows</span><span class="p">,</span> <span class="n">num_images</span><span class="p">,</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">images</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">/</span> <span class="mi">255</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="n">rows</span><span class="p">,</span> <span class="n">num_images</span><span class="p">,</span> <span class="n">num_images</span> <span class="o">+</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">masks</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="k">if</span> <span class="n">pred_masks</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="n">rows</span><span 
class="p">,</span> <span class="n">num_images</span><span class="p">,</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span> <span class="o">+</span> <span class="mi">2</span> <span class="o">*</span> <span class="n">num_images</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">pred_masks</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="n">plot_images_masks</span><span class="p">(</span><span class="n">batch</span><span class="p">[</span><span class="s2">"images"</span><span class="p">],</span> <span class="n">batch</span><span class="p">[</span><span class="s2">"segmentation_masks"</span><span class="p">])</span> </code></pre></div> <p><img alt="png" src="/img/guides/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_18_0.png" /></p> <p>The same preprocessing is applied to the evaluation dataset <code>eval_ds</code>.</p> <div class="codehilite"><pre><span></span><code><span class="n">eval_ds</span> <span class="o">=</span> <span class="n">preprocess_inputs</span><span class="p">(</span><span class="n">eval_ds</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="data-augmentation">Data Augmentation</h2> <p>Keras provides a variety of image augmentation options. In this example, we use the <code>RandomFlip</code> layer to augment the training dataset. With its default mode (<code>"horizontal_and_vertical"</code>), <code>RandomFlip</code> randomly flips each training image horizontally and/or vertically, and because the batch is passed as a dictionary, the matching segmentation mask is flipped in the same way. 
This can help improve the model's robustness to changes in the orientation of objects in the images.</p> <div class="codehilite"><pre><span></span><code><span class="n">train_ds</span> <span class="o">=</span> <span class="n">train_ds</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">RandomFlip</span><span class="p">())</span> <span class="n">batch</span> <span class="o">=</span> <span class="n">train_ds</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">get_single_element</span><span class="p">()</span> <span class="n">plot_images_masks</span><span class="p">(</span><span class="n">batch</span><span class="p">[</span><span class="s2">"images"</span><span class="p">],</span> <span class="n">batch</span><span class="p">[</span><span class="s2">"segmentation_masks"</span><span class="p">])</span> </code></pre></div> <p><img alt="png" src="/img/guides/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_22_0.png" /></p> <hr /> <h2 id="model-configuration">Model Configuration</h2> <p>Feel free to modify the training configuration and observe how the results change. This is a great exercise for getting a better understanding of the training pipeline.</p> <p>The optimizer uses the learning rate schedule to compute the learning rate at every training step (not once per epoch), and then uses that rate to update the model's weights. Here the schedule is a cosine decay: the learning rate starts at its initial value and decreases along a half cosine curve, reaching zero at the end of the decay period. With a batch size of 4, the training dataset (<code>sbd_train</code>) contains 2124 batches. 
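</p> <p>To make the schedule concrete, here is a minimal pure-Python sketch of the cosine decay formula, assuming the <code>CosineDecay</code> defaults used in this guide (<code>alpha=0</code>, no warmup). It illustrates the math only and is not the Keras implementation; the helper name <code>cosine_decay_lr</code> is hypothetical.</p> <div class="codehilite"><pre><span></span><code>import math

BATCH_SIZE = 4
INITIAL_LR = 0.007 * BATCH_SIZE / 16  # 0.007 scaled linearly by batch size
DECAY_STEPS = 2124  # EPOCHS * 2124 batches, with EPOCHS = 1


def cosine_decay_lr(step, initial_lr=INITIAL_LR, decay_steps=DECAY_STEPS, alpha=0.0):
    """Hypothetical helper: cosine-decayed learning rate at an optimizer step."""
    # Past decay_steps the rate stays at its floor, alpha * initial_lr.
    step = min(step, decay_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * step / decay_steps))
    return initial_lr * ((1.0 - alpha) * cosine + alpha)


# Print the learning rate at a few points along one epoch.
for step in (0, 531, 1062, 1593, 2124):
    print(step, round(cosine_decay_lr(step), 6))
</code></pre></div> <p>Halfway through (step 1062) the rate is half of <code>INITIAL_LR</code>, and at step 2124 it reaches zero.</p> <p>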
The dataset cardinality matters for the decay because it determines how many steps the model trains for. The initial learning rate is 0.007 scaled linearly by batch size (<code>0.007 * BATCH_SIZE / 16</code>, i.e. 0.00175 for a batch size of 4), and the number of decay steps is <code>EPOCHS * 2124</code>. With <code>EPOCHS = 1</code>, the learning rate starts at <code>INITIAL_LR</code> and decreases to zero over 2124 steps. <img alt="png" src="/img/guides/semantic_segmentation_deeplab_v3/learning_rate_schedule.png" /></p> <div class="codehilite"><pre><span></span><code><span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">4</span> <span class="n">INITIAL_LR</span> <span class="o">=</span> <span class="mf">0.007</span> <span class="o">*</span> <span class="n">BATCH_SIZE</span> <span class="o">/</span> <span class="mi">16</span> <span class="n">EPOCHS</span> <span class="o">=</span> <span class="mi">1</span> <span class="n">NUM_CLASSES</span> <span class="o">=</span> <span class="mi">21</span> <span class="n">learning_rate</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">schedules</span><span class="o">.</span><span class="n">CosineDecay</span><span class="p">(</span> <span class="n">INITIAL_LR</span><span class="p">,</span> <span class="n">decay_steps</span><span class="o">=</span><span class="n">EPOCHS</span> <span class="o">*</span> <span class="mi">2124</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <p>Let's use the <code>resnet_50_imagenet</code> pretrained weights as the image encoder for the model. This implementation can be used either as DeepLabV3 or, with an additional decoder block, as DeepLabV3+. For DeepLabV3+, we instantiate a <code>DeepLabV3Backbone</code> and set <code>low_level_feature_key</code> to <code>P2</code>, the pyramid-level output of <code>resnet_50_imagenet</code> whose low-level features are fed to the decoder block. 
To use the plain DeepLabV3 architecture instead, omit <code>low_level_feature_key</code>, which defaults to <code>None</code>.</p> <p>Then we create a <code>DeepLabV3ImageSegmenter</code> instance. The <code>num_classes</code> parameter specifies the number of classes the model will be trained to segment, and the <code>preprocessor</code> argument applies preprocessing to the input images and masks.</p> <div class="codehilite"><pre><span></span><code><span class="n">image_encoder</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">Backbone</span><span class="o">.</span><span class="n">from_preset</span><span class="p">(</span><span class="s2">"resnet_50_imagenet"</span><span class="p">)</span> <span class="n">deeplab_backbone</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">DeepLabV3Backbone</span><span class="p">(</span> <span class="n">image_encoder</span><span class="o">=</span><span class="n">image_encoder</span><span class="p">,</span> <span class="n">low_level_feature_key</span><span class="o">=</span><span class="s2">"P2"</span><span class="p">,</span> <span class="n">spatial_pyramid_pooling_key</span><span class="o">=</span><span class="s2">"P5"</span><span class="p">,</span> <span class="n">dilation_rates</span><span class="o">=</span><span class="p">[</span><span class="mi">6</span><span class="p">,</span> <span class="mi">12</span><span class="p">,</span> <span class="mi">18</span><span class="p">],</span> <span class="n">upsampling_size</span><span class="o">=</span><span class="mi">8</span><span class="p">,</span> <span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras_hub</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">DeepLabV3ImageSegmenter</span><span 
class="p">(</span> <span class="n">backbone</span><span class="o">=</span><span class="n">deeplab_backbone</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="mi">21</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"softmax"</span><span class="p">,</span> <span class="n">preprocessor</span><span class="o">=</span><span class="n">preprocessor</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="compile-the-model">Compile the model</h2> <p>The model.compile() function sets up the training process for the model. It defines the - optimization algorithm - Stochastic Gradient Descent (SGD) - the loss function - categorical cross-entropy - the evaluation metrics - Mean IoU and categorical accuracy</p> <p>Semantic segmentation evaluation metrics:</p> <p>Mean Intersection over Union (MeanIoU): MeanIoU measures how well a semantic segmentation model accurately identifies and delineates different objects or regions in an image. It calculates the overlap between predicted and actual object boundaries, providing a score between 0 and 1, where 1 represents a perfect match.</p> <p>Categorical Accuracy: Categorical Accuracy measures the proportion of correctly classified pixels in an image. 
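</p> <p>To make the difference between the two metrics concrete, here is a small NumPy sketch that computes both on a toy 4×4 mask with two classes: background (0) and a 2×2 object (1), of which the prediction finds only one pixel. Per-class IoU is read off a confusion matrix as TP / (TP + FP + FN) and then averaged over classes; pixel accuracy is the trace of the confusion matrix divided by the number of pixels. It illustrates the definitions only and is not the <code>keras.metrics</code> implementation; the masks and helper functions are hypothetical.</p> <div class="codehilite"><pre><span></span><code>import numpy as np


def confusion_matrix(y_true, y_pred, num_classes):
    # cm[i, j] counts pixels whose true class is i and predicted class is j.
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (y_true.ravel(), y_pred.ravel()), 1)
    return cm


def mean_iou(cm):
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp  # predicted as class c, but truly another class
    fn = cm.sum(axis=1) - tp  # truly class c, but predicted as another class
    iou = tp / (tp + fp + fn)
    return iou, iou.mean()


def pixel_accuracy(cm):
    # Fraction of pixels whose predicted class equals the true class.
    return np.trace(cm) / cm.sum()


# Toy 4x4 masks: class 0 is background, class 1 is a 2x2 object.
y_true = np.array([[0, 0, 0, 0],
                   [0, 1, 1, 0],
                   [0, 1, 1, 0],
                   [0, 0, 0, 0]])
# The prediction finds only one of the four object pixels.
y_pred = np.zeros_like(y_true)
y_pred[1, 1] = 1

cm = confusion_matrix(y_true, y_pred, num_classes=2)
per_class_iou, miou = mean_iou(cm)
print(per_class_iou, miou, pixel_accuracy(cm))
</code></pre></div> <p>Here the per-class IoU is 0.8 for the background and 0.25 for the object, so mean IoU is 0.525, while pixel accuracy is 13/16 = 0.8125 because the background dominates the pixel count. The training log further down shows the same pattern, with categorical accuracy well above mean IoU.</p> <p>Categorical Accuracy pools every pixel into a single count. 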
It gives a simple percentage indicating how accurately the model classifies pixels across the entire image.</p> <p>In essence, MeanIoU averages an overlap score per class, so small objects count as much as large ones, while Categorical Accuracy gives a broad overview of overall pixel-level correctness that large classes such as the background can dominate.</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span>
    <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">SGD</span><span class="p">(</span>
        <span class="n">learning_rate</span><span class="o">=</span><span class="n">learning_rate</span><span class="p">,</span>
        <span class="n">weight_decay</span><span class="o">=</span><span class="mf">0.0001</span><span class="p">,</span>
        <span class="n">momentum</span><span class="o">=</span><span class="mf">0.9</span><span class="p">,</span>
        <span class="n">clipnorm</span><span class="o">=</span><span class="mf">10.0</span>
    <span class="p">),</span>
    <span class="c1"># The segmenter already applies softmax, so the loss receives probabilities.</span>
    <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">CategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">False</span><span class="p">),</span>
    <span class="n">metrics</span><span class="o">=</span><span class="p">[</span>
        <span class="c1"># Masks are one-hot and predictions are per-class scores, not integer ids.</span>
        <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">MeanIoU</span><span class="p">(</span>
            <span class="n">num_classes</span><span class="o">=</span><span class="n">NUM_CLASSES</span><span class="p">,</span>
            <span class="n">sparse_y_true</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
            <span class="n">sparse_y_pred</span><span class="o">=</span><span class="kc">False</span>
        <span class="p">),</span>
        <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">CategoricalAccuracy</span><span class="p">(),</span>
    <span class="p">],</span>
<span class="p">)</span>

<span class="n">model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span>
</code></pre></div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Preprocessor: "deep_lab_v3_image_segmenter_preprocessor"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Config </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ deep_lab_v3_image_converter (<span style="color: #0087ff; text-decoration-color: #0087ff">DeepLabV3ImageConverter</span>) │ Image size: (<span style="color: #00af00; text-decoration-color: #00af00">512</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>) │ └───────────────────────────────────────────────────────────────┴──────────────────────────────────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "deep_lab_v3_image_segmenter"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓ ┃<span
style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩ │ inputs (<span style="color: #0087ff; text-decoration-color: #0087ff">InputLayer</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├───────────────────────────────────────────────┼────────────────────────────────────┼─────────────────────┤ │ deep_lab_v3_backbone (<span style="color: #0087ff; text-decoration-color: #0087ff">DeepLabV3Backbone</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">256</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">39,190,656</span> │ ├───────────────────────────────────────────────┼────────────────────────────────────┼─────────────────────┤ │ segmentation_output (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">21</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">5,376</span> │ 
└───────────────────────────────────────────────┴────────────────────────────────────┴─────────────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">39,196,032</span> (149.52 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">39,139,232</span> (149.30 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">56,800</span> (221.88 KB) </pre> <p>The utility function <code>dict_to_tuple</code> converts each dictionary element of the training and validation datasets into a tuple of images and one-hot encoded segmentation masks, the format used to train and evaluate the DeepLabV3+ model. One-hot masks match the <code>CategoricalCrossentropy</code> loss and the <code>sparse_y_true=False</code> setting of <code>MeanIoU</code> chosen when compiling the model.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">dict_to_tuple</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
    <span class="c1"># Drop the trailing channel axis of the integer mask, then one-hot encode it.</span>
    <span class="k">return</span> <span class="n">x</span><span class="p">[</span><span class="s2">"images"</span><span class="p">],</span> <span class="n">tf</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span>
        <span class="n">tf</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">x</span><span class="p">[</span><span class="s2">"segmentation_masks"</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">),</span> <span class="s2">"int32"</span><span class="p">),</span> <span class="n">NUM_CLASSES</span>
    <span class="p">)</span>


<span class="n">train_ds</span> <span class="o">=</span> <span class="n">train_ds</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">dict_to_tuple</span><span class="p">)</span>
<span class="n">eval_ds</span> <span class="o">=</span> <span class="n">eval_ds</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">dict_to_tuple</span><span class="p">)</span>

<span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_ds</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">eval_ds</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">EPOCHS</span><span class="p">)</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 1/Unknown 40s 40s/step - categorical_accuracy: 0.1191 - loss: 3.0568 - mean_io_u: 0.0118 </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>2124/2124 ━━━━━━━━━━━━━━━━━━━━ 281s 114ms/step - categorical_accuracy: 0.7286 - loss: 1.0707 - mean_io_u: 0.0926 - val_categorical_accuracy: 0.8199 - val_loss: 0.5900 - val_mean_io_u: 0.3265 </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>&lt;keras.src.callbacks.history.History at 0x7fd7a897f8d0&gt; </code></pre></div> </div> <hr /> <h2 id="predictions-with-trained-model">Predictions with trained model</h2> <p>Now that the DeepLabV3+ model has finished training, let's test it by making predictions on a few sample images. 
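</p> <p>Both the one-hot masks produced by <code>dict_to_tuple</code> and the model's softmax output carry a trailing class axis of size 21, so the prediction code that follows uses <code>argmax</code> over the last axis to recover one class id per pixel and <code>expand_dims</code> to restore a single channel for plotting. The NumPy sketch below, which uses random toy masks and <code>np.eye</code> as a stand-in for <code>tf.one_hot</code>, checks that this round trip recovers the original integer class map; the shapes are illustrative assumptions.</p> <div class="codehilite"><pre><span></span><code>import numpy as np

NUM_CLASSES = 21
rng = np.random.default_rng(0)

# Toy integer masks with a trailing channel axis, shaped like the raw
# "segmentation_masks" entries: (batch, height, width, 1).
masks = rng.integers(0, NUM_CLASSES, size=(2, 8, 8, 1))

# Same idea as dict_to_tuple: drop the channel axis, then one-hot encode.
one_hot = np.eye(NUM_CLASSES)[np.squeeze(masks, axis=-1)]
print(one_hot.shape)  # (2, 8, 8, 21)

# Same idea as the prediction post-processing: argmax over the class axis,
# then expand_dims to restore a single channel for plotting.
recovered = np.expand_dims(np.argmax(one_hot, axis=-1), axis=-1)
print(recovered.shape)  # (2, 8, 8, 1)
print(np.array_equal(recovered, masks))  # True
</code></pre></div> <p>For the model's predictions, the same <code>argmax</code> picks the most probable class at each pixel, so the predicted and ground-truth maps end up in the same integer format.</p> <p>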
Note: for demonstration purposes, the model has been trained for only 1 epoch; train for more epochs to get better accuracy and results.</p> <div class="codehilite"><pre><span></span><code><span class="n">test_ds</span> <span class="o">=</span> <span class="n">load</span><span class="p">(</span><span class="n">split</span><span class="o">=</span><span class="s2">"sbd_eval"</span><span class="p">)</span>
<span class="n">test_ds</span> <span class="o">=</span> <span class="n">preprocess_inputs</span><span class="p">(</span><span class="n">test_ds</span><span class="p">)</span>
<span class="c1"># Convert the dictionaries to (images, one-hot masks) tuples, as for training.</span>
<span class="n">test_ds</span> <span class="o">=</span> <span class="n">test_ds</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">dict_to_tuple</span><span class="p">)</span>

<span class="n">images</span><span class="p">,</span> <span class="n">masks</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">test_ds</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="mi">1</span><span class="p">)))</span>
<span class="n">images</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">convert_to_tensor</span><span class="p">(</span><span class="n">images</span><span class="p">)</span>
<span class="n">masks</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">convert_to_tensor</span><span class="p">(</span><span class="n">masks</span><span class="p">)</span>
<span class="c1"># Collapse the class axis to one class id per pixel, keeping a channel axis.</span>
<span class="n">preds</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">images</span><span class="p">),</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">),</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>
<span class="n">masks</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">masks</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">),</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

<span class="n">plot_images_masks</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">masks</span><span class="p">,</span> <span class="n">preds</span><span class="p">)</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>1/1 ━━━━━━━━━━━━━━━━━━━━ 3s 3s/step </code></pre></div> </div> <p><img alt="png" src="/img/guides/semantic_segmentation_deeplab_v3/semantic_segmentation_deeplab_v3_32_2.png" /></p> <p>Here are some additional tips for using the KerasHub DeepLabV3 model:</p> <ul> <li>The model can be trained on a variety of datasets, such as COCO, PASCAL VOC, and Cityscapes.</li> <li>The model can be fine-tuned on a custom dataset to improve its performance on a specific task.</li> <li>The model can be used for inference on new images; whether it runs in real time depends on the hardware and the input resolution.</li> <li>Also, check out KerasHub's other segmentation models.</li> </ul> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#semantic-segmentation-with-kerashub'>Semantic Segmentation with KerasHub</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#background'>Background</a> </div> <div class='k-outline-depth-3'> <a href='#references'>References</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup-and-imports'>Setup and Imports</a> </div> 
<div class='k-outline-depth-2'> ◆ <a href='#perform-semantic-segmentation-with-a-pretrained-deeplabv3-model'>Perform semantic segmentation with a pretrained DeepLabv3+ model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-a-custom-semantic-segmentation-model'>Train a custom semantic segmentation model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#download-the-data'>Download the data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#load-the-dataset'>Load the dataset</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#preprocess-the-data'>Preprocess the data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#data-augmentation'>Data Augmentation</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#model-configuration'>Model Configuration</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#compile-the-model'>Compile the model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#predictions-with-trained-model'>Predictions with trained model</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>