<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/pointnet_segmentation/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Point cloud segmentation with PointNet"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Point cloud segmentation with PointNet"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Point cloud segmentation with PointNet</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 
'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a class="nav-sublink2" href="/examples/vision/mobilevit/">A mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" 
href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using Composable Fully-Convolutional Networks</a> <a class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a 
class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2 active" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved Robustness</a> <a class="nav-sublink2" href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" 
href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/bit/">Image Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" 
href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a class="nav-sublink2" href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a 
class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) 
+ 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + encodeURIComponent(query); return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a
href='/examples/vision/'>Computer Vision</a> / Point cloud segmentation with PointNet </div> <div class='k-content'> <h1 id="point-cloud-segmentation-with-pointnet">Point cloud segmentation with PointNet</h1> <p><strong>Author:</strong> <a href="https://github.com/soumik12345">Soumik Rakshit</a>, <a href="https://github.com/sayakpaul">Sayak Paul</a><br> <strong>Date created:</strong> 2020/10/23<br> <strong>Last modified:</strong> 2020/10/24<br> <strong>Description:</strong> Implementation of a PointNet-based model for segmenting point clouds.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/pointnet_segmentation.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/pointnet_segmentation.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>A "point cloud" is an important type of data structure for storing geometric shape data. Due to its irregular format, it's often transformed into regular 3D voxel grids or collections of images before being used in deep learning applications, a step which makes the data unnecessarily large. The PointNet family of models solves this problem by directly consuming point clouds, respecting the permutation-invariance property of the point data. 
The PointNet family of models provides a simple, unified architecture for applications ranging from <strong>object classification</strong> and <strong>part segmentation</strong> to <strong>scene semantic parsing</strong>.</p> <p>In this example, we demonstrate the implementation of the PointNet architecture for shape segmentation.</p> <h3 id="references">References</h3> <ul> <li><a href="https://arxiv.org/abs/1612.00593">PointNet: Deep Learning on Point Sets for 3D Classification and Segmentation</a></li> <li><a href="https://keras.io/examples/vision/pointnet/">Point cloud classification with PointNet</a></li> <li><a href="https://arxiv.org/abs/1506.02025">Spatial Transformer Networks</a></li> </ul> <hr /> <h2 id="imports">Imports</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">json</span>
<span class="kn">import</span> <span class="nn">random</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span>
<span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
<span class="kn">from</span> <span class="nn">glob</span> <span class="kn">import</span> <span class="n">glob</span>

<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span>  <span class="c1"># For tf.data</span>
<span class="kn">import</span> <span class="nn">keras</span>
<span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span>

<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
</code></pre></div> <hr /> <h2 id="downloading-dataset">Downloading Dataset</h2> <p>The <a
href="https://shapenet.org/">ShapeNet dataset</a> is an ongoing effort to establish a richly annotated, large-scale dataset of 3D shapes. <strong>ShapeNetCore</strong> is a subset of the full ShapeNet dataset with clean single 3D models and manually verified category and alignment annotations. It covers 55 common object categories, with about 51,300 unique 3D models.</p> <p>For this example, we use one of the 12 object categories of <a href="http://cvgl.stanford.edu/projects/pascal3d.html">PASCAL 3D+</a>, included as part of the ShapeNetCore dataset.</p> <div class="codehilite"><pre><span></span><code><span class="n">dataset_url</span> <span class="o">=</span> <span class="s2">"https://git.io/JiY4i"</span>

<span class="n">dataset_path</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">get_file</span><span class="p">(</span>
    <span class="n">fname</span><span class="o">=</span><span class="s2">"shapenet.zip"</span><span class="p">,</span>
    <span class="n">origin</span><span class="o">=</span><span class="n">dataset_url</span><span class="p">,</span>
    <span class="n">cache_subdir</span><span class="o">=</span><span class="s2">"datasets"</span><span class="p">,</span>
    <span class="n">hash_algorithm</span><span class="o">=</span><span class="s2">"auto"</span><span class="p">,</span>
    <span class="n">extract</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    <span class="n">archive_format</span><span class="o">=</span><span class="s2">"auto"</span><span class="p">,</span>
    <span class="n">cache_dir</span><span class="o">=</span><span class="s2">"datasets"</span><span class="p">,</span>
<span class="p">)</span>
</code></pre></div> <hr /> <h2 id="loading-the-dataset">Loading the dataset</h2> <p>We parse the dataset metadata in order to easily map model categories to their respective directories and segmentation classes to colors for the purpose
of visualization.</p> <div class="codehilite"><pre><span></span><code><span class="k">with</span> <span class="nb">open</span><span class="p">(</span><span class="s2">"/tmp/.keras/datasets/PartAnnotation/metadata.json"</span><span class="p">)</span> <span class="k">as</span> <span class="n">json_file</span><span class="p">:</span>
    <span class="n">metadata</span> <span class="o">=</span> <span class="n">json</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">json_file</span><span class="p">)</span>

<span class="nb">print</span><span class="p">(</span><span class="n">metadata</span><span class="p">)</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>{'Airplane': {'directory': '02691156', 'lables': ['wing', 'body', 'tail', 'engine'], 'colors': ['blue', 'green', 'red', 'pink']}, 'Bag': {'directory': '02773838', 'lables': ['handle', 'body'], 'colors': ['blue', 'green']}, 'Cap': {'directory': '02954340', 'lables': ['panels', 'peak'], 'colors': ['blue', 'green']}, 'Car': {'directory': '02958343', 'lables': ['wheel', 'hood', 'roof'], 'colors': ['blue', 'green', 'red']}, 'Chair': {'directory': '03001627', 'lables': ['leg', 'arm', 'back', 'seat'], 'colors': ['blue', 'green', 'red', 'pink']}, 'Earphone': {'directory': '03261776', 'lables': ['earphone', 'headband'], 'colors': ['blue', 'green']}, 'Guitar': {'directory': '03467517', 'lables': ['head', 'body', 'neck'], 'colors': ['blue', 'green', 'red']}, 'Knife': {'directory': '03624134', 'lables': ['handle', 'blade'], 'colors': ['blue', 'green']}, 'Lamp': {'directory': '03636649', 'lables': ['canopy', 'lampshade', 'base'], 'colors': ['blue', 'green', 'red']}, 'Laptop': {'directory': '03642806', 'lables': ['keyboard'], 'colors': ['blue']}, 'Motorbike': {'directory': '03790512', 'lables': ['wheel', 'handle', 'gas_tank', 'light', 'seat'], 'colors': ['blue', 'green', 'red', 'pink', 'yellow']}, 'Mug': {'directory': '03797390', 'lables': ['handle'], 'colors': ['blue']}, 'Pistol': {'directory': '03948459', 'lables': ['trigger_and_guard', 'handle', 'barrel'], 'colors': ['blue', 'green', 'red']}, 'Rocket': {'directory': '04099429', 'lables': ['nose', 'body', 'fin'], 'colors': ['blue', 'green', 'red']}, 'Skateboard': {'directory': '04225987', 'lables': ['wheel', 'deck'], 'colors': ['blue', 'green']}, 'Table': {'directory': '04379243', 'lables': ['leg', 'top'], 'colors': ['blue', 'green']}}
</code></pre></div> </div> <p>In this example, we train PointNet to segment the parts of an <code>Airplane</code> model.</p> <div class="codehilite"><pre><span></span><code><span class="n">points_dir</span> <span class="o">=</span> <span class="s2">"/tmp/.keras/datasets/PartAnnotation/</span><span class="si">{}</span><span class="s2">/points"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span>
    <span class="n">metadata</span><span class="p">[</span><span class="s2">"Airplane"</span><span class="p">][</span><span class="s2">"directory"</span><span class="p">]</span>
<span class="p">)</span>
<span class="n">labels_dir</span> <span class="o">=</span> <span class="s2">"/tmp/.keras/datasets/PartAnnotation/</span><span class="si">{}</span><span class="s2">/points_label"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span>
    <span class="n">metadata</span><span class="p">[</span><span class="s2">"Airplane"</span><span class="p">][</span><span class="s2">"directory"</span><span class="p">]</span>
<span class="p">)</span>
<span class="n">LABELS</span> <span class="o">=</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">"Airplane"</span><span class="p">][</span><span class="s2">"lables"</span><span class="p">]</span>
<span class="n">COLORS</span> <span class="o">=</span> <span class="n">metadata</span><span class="p">[</span><span class="s2">"Airplane"</span><span class="p">][</span><span class="s2">"colors"</span><span
class="p">]</span>

<span class="n">VAL_SPLIT</span> <span class="o">=</span> <span class="mf">0.2</span>
<span class="n">NUM_SAMPLE_POINTS</span> <span class="o">=</span> <span class="mi">1024</span>
<span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">32</span>
<span class="n">EPOCHS</span> <span class="o">=</span> <span class="mi">60</span>
<span class="n">INITIAL_LR</span> <span class="o">=</span> <span class="mf">1e-3</span>
</code></pre></div> <hr /> <h2 id="structuring-the-dataset">Structuring the dataset</h2> <p>We generate the following in-memory data structures from the Airplane point clouds and their labels:</p> <ul> <li><code>point_clouds</code> is a list of <code>np.array</code> objects that represent the point cloud data in the form of x, y and z coordinates. Axis 0 indexes the points in the point cloud, while axis 1 holds the three coordinates.</li> <li><code>test_point_clouds</code> is in the same format as <code>point_clouds</code>, but has no corresponding labels.</li> <li><code>all_labels</code> is a list that holds, for each point cloud, the label of every point as a string (needed mainly for visualization purposes), corresponding to the <code>point_clouds</code> list.</li> <li><code>point_cloud_labels</code> is a list of <code>np.array</code> objects that represent the label of every point in one-hot encoded form, corresponding to the <code>point_clouds</code> list.</li> </ul> <div class="codehilite"><pre><span></span><code><span class="n">point_clouds</span><span class="p">,</span> <span class="n">test_point_clouds</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[]</span>
<span class="n">point_cloud_labels</span><span class="p">,</span> <span class="n">all_labels</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[]</span>

<span class="n">points_files</span> <span class="o">=</span> <span class="n">glob</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">points_dir</span><span class="p">,</span> <span class="s2">"*.pts"</span><span class="p">))</span>
<span class="k">for</span> <span class="n">point_file</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">points_files</span><span class="p">):</span>
    <span class="n">point_cloud</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">loadtxt</span><span class="p">(</span><span class="n">point_file</span><span class="p">)</span>
    <span class="k">if</span> <span class="n">point_cloud</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">&lt;</span> <span class="n">NUM_SAMPLE_POINTS</span><span class="p">:</span>
        <span class="k">continue</span>

    <span class="c1"># Get the file-id of the current point cloud for parsing its</span>
    <span class="c1"># labels.</span>
    <span class="n">file_id</span> <span class="o">=</span> <span class="n">point_file</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"/"</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">"."</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span>
    <span class="n">label_data</span><span class="p">,</span> <span class="n">num_labels</span> <span class="o">=</span> <span class="p">{},</span> <span class="mi">0</span>
    <span class="k">for</span> <span class="n">label</span> <span class="ow">in</span> <span class="n">LABELS</span><span class="p">:</span>
        <span class="n">label_file</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">labels_dir</span><span class="p">,</span> <span class="n">label</span><span class="p">,</span> <span class="n">file_id</span> <span class="o">+</span> <span class="s2">".seg"</span><span class="p">)</span>
        <span class="k">if</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="n">label_file</span><span class="p">):</span>
            <span class="n">label_data</span><span class="p">[</span><span class="n">label</span><span class="p">]</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">loadtxt</span><span class="p">(</span><span class="n">label_file</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"float32"</span><span class="p">)</span>
            <span class="n">num_labels</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">label_data</span><span class="p">[</span><span class="n">label</span><span class="p">])</span>

    <span class="c1"># Point clouds having labels will be our training samples.</span>
    <span class="k">try</span><span class="p">:</span>
        <span class="n">label_map</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"none"</span><span class="p">]</span> <span class="o">*</span> <span class="n">num_labels</span>
        <span class="k">for</span> <span class="n">label</span> <span class="ow">in</span> <span class="n">LABELS</span><span class="p">:</span>
            <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">data</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">label_data</span><span class="p">[</span><span class="n">label</span><span class="p">]):</span>
                <span class="n">label_map</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">label</span> <span class="k">if</span> <span class="n">data</span> <span class="o">==</span> <span class="mi">1</span> <span class="k">else</span> <span class="n">label_map</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
        <span class="n">label_data</span> <span class="o">=</span> <span class="p">[</span>
            <span class="n">LABELS</span><span class="o">.</span><span class="n">index</span><span class="p">(</span><span class="n">label</span><span class="p">)</span> <span class="k">if</span> <span class="n">label</span> <span class="o">!=</span> <span class="s2">"none"</span> <span class="k">else</span> <span class="nb">len</span><span class="p">(</span><span class="n">LABELS</span><span class="p">)</span>
            <span class="k">for</span> <span class="n">label</span> <span class="ow">in</span> <span class="n">label_map</span>
        <span class="p">]</span>
        <span class="c1"># Apply one-hot encoding to the dense label representation.</span>
        <span class="n">label_data</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">to_categorical</span><span class="p">(</span><span class="n">label_data</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">LABELS</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>

        <span class="n">point_clouds</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">point_cloud</span><span class="p">)</span>
        <span
class="n">point_cloud_labels</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">label_data</span><span class="p">)</span> <span class="n">all_labels</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">label_map</span><span class="p">)</span> <span class="k">except</span> <span class="ne">KeyError</span><span class="p">:</span> <span class="n">test_point_clouds</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">point_cloud</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>100%|██████████████████████████████████████████████████████████████████████| 4045/4045 [01:30<00:00, 44.54it/s] </code></pre></div> </div> <p>Next, we take a look at some samples from the in-memory arrays we just generated:</p> <div class="codehilite"><pre><span></span><code><span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">5</span><span class="p">):</span> <span class="n">i</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">point_clouds</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"point_clouds[</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">].shape:"</span><span class="p">,</span> <span class="n">point_clouds</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span 
class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"point_cloud_labels[</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">].shape:"</span><span class="p">,</span> <span class="n">point_cloud_labels</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">5</span><span class="p">):</span> <span class="nb">print</span><span class="p">(</span> <span class="sa">f</span><span class="s2">"all_labels[</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">][</span><span class="si">{</span><span class="n">j</span><span class="si">}</span><span class="s2">]:"</span><span class="p">,</span> <span class="n">all_labels</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">],</span> <span class="sa">f</span><span class="s2">"</span><span class="se">\t</span><span class="s2">point_cloud_labels[</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">][</span><span class="si">{</span><span class="n">j</span><span class="si">}</span><span class="s2">]:"</span><span class="p">,</span> <span class="n">point_cloud_labels</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">],</span> <span class="s2">"</span><span class="se">\n</span><span class="s2">"</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>point_clouds[333].shape: (2571, 3) point_cloud_labels[333].shape: (2571, 5) all_labels[333][0]: tail point_cloud_labels[333][0]: [0. 0. 1. 0. 
0.] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>all_labels[333][1]: wing point_cloud_labels[333][1]: [1. 0. 0. 0. 0.] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>all_labels[333][2]: tail point_cloud_labels[333][2]: [0. 0. 1. 0. 0.] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>all_labels[333][3]: engine point_cloud_labels[333][3]: [0. 0. 0. 1. 0.] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>all_labels[333][4]: wing point_cloud_labels[333][4]: [1. 0. 0. 0. 0.] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>point_clouds[3273].shape: (2571, 3) point_cloud_labels[3273].shape: (2571, 5) all_labels[3273][0]: body point_cloud_labels[3273][0]: [0. 1. 0. 0. 0.] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>all_labels[3273][1]: body point_cloud_labels[3273][1]: [0. 1. 0. 0. 0.] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>all_labels[3273][2]: tail point_cloud_labels[3273][2]: [0. 0. 1. 0. 0.] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>all_labels[3273][3]: wing point_cloud_labels[3273][3]: [1. 0. 0. 0. 0.] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>all_labels[3273][4]: wing point_cloud_labels[3273][4]: [1. 0. 0. 0. 0.] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>point_clouds[929].shape: (2571, 3) point_cloud_labels[929].shape: (2571, 5) all_labels[929][0]: body point_cloud_labels[929][0]: [0. 1. 0. 0. 0.] 
</code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>all_labels[929][1]: tail point_cloud_labels[929][1]: [0. 0. 1. 0. 0.] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>all_labels[929][2]: wing point_cloud_labels[929][2]: [1. 0. 0. 0. 0.] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>all_labels[929][3]: tail point_cloud_labels[929][3]: [0. 0. 1. 0. 0.] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>all_labels[929][4]: body point_cloud_labels[929][4]: [0. 1. 0. 0. 0.] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>point_clouds[496].shape: (2571, 3) point_cloud_labels[496].shape: (2571, 5) all_labels[496][0]: body point_cloud_labels[496][0]: [0. 1. 0. 0. 0.] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>all_labels[496][1]: body point_cloud_labels[496][1]: [0. 1. 0. 0. 0.] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>all_labels[496][2]: body point_cloud_labels[496][2]: [0. 1. 0. 0. 0.] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>all_labels[496][3]: wing point_cloud_labels[496][3]: [1. 0. 0. 0. 0.] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>all_labels[496][4]: body point_cloud_labels[496][4]: [0. 1. 0. 0. 0.] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>point_clouds[3508].shape: (2571, 3) point_cloud_labels[3508].shape: (2571, 5) all_labels[3508][0]: body point_cloud_labels[3508][0]: [0. 1. 0. 0. 0.] 
</code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>all_labels[3508][1]: body point_cloud_labels[3508][1]: [0. 1. 0. 0. 0.] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>all_labels[3508][2]: body point_cloud_labels[3508][2]: [0. 1. 0. 0. 0.] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>all_labels[3508][3]: body point_cloud_labels[3508][3]: [0. 1. 0. 0. 0.] </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>all_labels[3508][4]: body point_cloud_labels[3508][4]: [0. 1. 0. 0. 0.] </code></pre></div> </div> <p>Now, let's visualize some of the point clouds along with their labels.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">visualize_data</span><span class="p">(</span><span class="n">point_cloud</span><span class="p">,</span> <span class="n">labels</span><span class="p">):</span> <span class="n">df</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">DataFrame</span><span class="p">(</span> <span class="n">data</span><span class="o">=</span><span class="p">{</span> <span class="s2">"x"</span><span class="p">:</span> <span class="n">point_cloud</span><span class="p">[:,</span> <span class="mi">0</span><span class="p">],</span> <span class="s2">"y"</span><span class="p">:</span> <span class="n">point_cloud</span><span class="p">[:,</span> <span class="mi">1</span><span class="p">],</span> <span class="s2">"z"</span><span class="p">:</span> <span class="n">point_cloud</span><span class="p">[:,</span> <span class="mi">2</span><span class="p">],</span> <span class="s2">"label"</span><span class="p">:</span> <span class="n">labels</span><span class="p">,</span> <span class="p">}</span> <span class="p">)</span> <span class="n">fig</span> <span 
class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">15</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">axes</span><span class="p">(</span><span class="n">projection</span><span class="o">=</span><span class="s2">"3d"</span><span class="p">)</span> <span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">label</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">LABELS</span><span class="p">):</span> <span class="n">c_df</span> <span class="o">=</span> <span class="n">df</span><span class="p">[</span><span class="n">df</span><span class="p">[</span><span class="s2">"label"</span><span class="p">]</span> <span class="o">==</span> <span class="n">label</span><span class="p">]</span> <span class="k">try</span><span class="p">:</span> <span class="n">ax</span><span class="o">.</span><span class="n">scatter</span><span class="p">(</span> <span class="n">c_df</span><span class="p">[</span><span class="s2">"x"</span><span class="p">],</span> <span class="n">c_df</span><span class="p">[</span><span class="s2">"y"</span><span class="p">],</span> <span class="n">c_df</span><span class="p">[</span><span class="s2">"z"</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="n">label</span><span class="p">,</span> <span class="n">alpha</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">c</span><span class="o">=</span><span class="n">COLORS</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="p">)</span> <span class="k">except</span> <span 
class="ne">IndexError</span><span class="p">:</span> <span class="k">pass</span> <span class="n">ax</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="n">visualize_data</span><span class="p">(</span><span class="n">point_clouds</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">all_labels</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="n">visualize_data</span><span class="p">(</span><span class="n">point_clouds</span><span class="p">[</span><span class="mi">300</span><span class="p">],</span> <span class="n">all_labels</span><span class="p">[</span><span class="mi">300</span><span class="p">])</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/pointnet_segmentation/pointnet_segmentation_15_0.png" /></p> <p><img alt="png" src="/img/examples/vision/pointnet_segmentation/pointnet_segmentation_15_1.png" /></p> <h3 id="preprocessing">Preprocessing</h3> <p>Note that all the point clouds that we have loaded consist of a variable number of points, which makes it difficult for us to batch them together. In order to overcome this problem, we randomly sample a fixed number of points from each point cloud. 
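</p> <p>The sampling used in this example assumes that every point cloud holds at least as many points as the target size. As a minimal sketch of the fixed-size sampling idea, the snippet below uses a hypothetical helper, <code>sample_fixed_size</code>, which is not part of this example. Its fallback to sampling with replacement for small clouds, the random test data, and the value <code>1024</code> for the number of sampled points are assumptions made for illustration.</p>
<div class="codehilite"><pre><span></span><code>import numpy as np

NUM_SAMPLE_POINTS = 1024  # assumed value for the fixed number of sampled points


def sample_fixed_size(point_cloud, label_cloud, num_samples, rng):
    """Hypothetical helper: pick num_samples matching rows from both arrays."""
    num_points = len(point_cloud)
    # Number of points the cloud is short of the target (0 if it is big enough).
    num_missing = max(0, num_samples - num_points)
    # Points are re-used only when the cloud is smaller than the target size.
    indices = rng.choice(num_points, size=num_samples, replace=num_missing != 0)
    return point_cloud[indices], label_cloud[indices]


rng = np.random.default_rng(0)
cloud = rng.normal(size=(2571, 3))                 # stand-in point cloud
labels = np.eye(5)[rng.integers(0, 5, size=2571)]  # stand-in one-hot labels

sampled_cloud, sampled_labels = sample_fixed_size(
    cloud, labels, NUM_SAMPLE_POINTS, rng
)
print(sampled_cloud.shape, sampled_labels.shape)
</code></pre></div>
<p>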
We also normalize the point clouds in order to make the data scale-invariant.</p> <div class="codehilite"><pre><span></span><code><span class="k">for</span> <span class="n">index</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">point_clouds</span><span class="p">))):</span> <span class="n">current_point_cloud</span> <span class="o">=</span> <span class="n">point_clouds</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="n">current_label_cloud</span> <span class="o">=</span> <span class="n">point_cloud_labels</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="n">current_labels</span> <span class="o">=</span> <span class="n">all_labels</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="n">num_points</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">current_point_cloud</span><span class="p">)</span> <span class="c1"># Randomly sampling respective indices.</span> <span class="n">sampled_indices</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="n">num_points</span><span class="p">)),</span> <span class="n">NUM_SAMPLE_POINTS</span><span class="p">)</span> <span class="c1"># Sampling points corresponding to sampled indices.</span> <span class="n">sampled_point_cloud</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">current_point_cloud</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> 
<span class="n">i</span> <span class="ow">in</span> <span class="n">sampled_indices</span><span class="p">])</span> <span class="c1"># Sampling corresponding one-hot encoded labels.</span> <span class="n">sampled_label_cloud</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">current_label_cloud</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">sampled_indices</span><span class="p">])</span> <span class="c1"># Sampling corresponding labels for visualization.</span> <span class="n">sampled_labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">current_labels</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">sampled_indices</span><span class="p">])</span> <span class="c1"># Normalizing sampled point cloud.</span> <span class="n">norm_point_cloud</span> <span class="o">=</span> <span class="n">sampled_point_cloud</span> <span class="o">-</span> <span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">sampled_point_cloud</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">norm_point_cloud</span> <span class="o">/=</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">linalg</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">norm_point_cloud</span><span class="p">,</span> <span class="n">axis</span><span 
class="o">=</span><span class="mi">1</span><span class="p">))</span> <span class="n">point_clouds</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="o">=</span> <span class="n">norm_point_cloud</span> <span class="n">point_cloud_labels</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="o">=</span> <span class="n">sampled_label_cloud</span> <span class="n">all_labels</span><span class="p">[</span><span class="n">index</span><span class="p">]</span> <span class="o">=</span> <span class="n">sampled_labels</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>100%|█████████████████████████████████████████████████████████████████████| 3694/3694 [00:08<00:00, 446.45it/s] </code></pre></div> </div> <p>Let's visualize the sampled and normalized point clouds along with their corresponding labels.</p> <div class="codehilite"><pre><span></span><code><span class="n">visualize_data</span><span class="p">(</span><span class="n">point_clouds</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">all_labels</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="n">visualize_data</span><span class="p">(</span><span class="n">point_clouds</span><span class="p">[</span><span class="mi">300</span><span class="p">],</span> <span class="n">all_labels</span><span class="p">[</span><span class="mi">300</span><span class="p">])</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/pointnet_segmentation/pointnet_segmentation_19_0.png" /></p> <p><img alt="png" src="/img/examples/vision/pointnet_segmentation/pointnet_segmentation_19_1.png" /></p> <h3 id="creating-tensorflow-datasets">Creating TensorFlow datasets</h3> <p>We create <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><code>tf.data.Dataset</code></a> objects for the training 
and validation data. We also augment the training point clouds by applying random jitter to them.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">load_data</span><span class="p">(</span><span class="n">point_cloud_batch</span><span class="p">,</span> <span class="n">label_cloud_batch</span><span class="p">):</span> <span class="n">point_cloud_batch</span><span class="o">.</span><span class="n">set_shape</span><span class="p">([</span><span class="n">NUM_SAMPLE_POINTS</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="n">label_cloud_batch</span><span class="o">.</span><span class="n">set_shape</span><span class="p">([</span><span class="n">NUM_SAMPLE_POINTS</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">LABELS</span><span class="p">)</span> <span class="o">+</span> <span class="mi">1</span><span class="p">])</span> <span class="k">return</span> <span class="n">point_cloud_batch</span><span class="p">,</span> <span class="n">label_cloud_batch</span> <span class="k">def</span> <span class="nf">augment</span><span class="p">(</span><span class="n">point_cloud_batch</span><span class="p">,</span> <span class="n">label_cloud_batch</span><span class="p">):</span> <span class="n">noise</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span> <span class="n">tf</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">label_cloud_batch</span><span class="p">),</span> <span class="o">-</span><span class="mf">0.001</span><span class="p">,</span> <span class="mf">0.001</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">float64</span> <span class="p">)</span> <span 
class="n">point_cloud_batch</span> <span class="o">+=</span> <span class="n">noise</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:</span><span class="mi">3</span><span class="p">]</span> <span class="k">return</span> <span class="n">point_cloud_batch</span><span class="p">,</span> <span class="n">label_cloud_batch</span> <span class="k">def</span> <span class="nf">generate_dataset</span><span class="p">(</span><span class="n">point_clouds</span><span class="p">,</span> <span class="n">label_clouds</span><span class="p">,</span> <span class="n">is_training</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">point_clouds</span><span class="p">,</span> <span class="n">label_clouds</span><span class="p">))</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">BATCH_SIZE</span> <span class="o">*</span> <span class="mi">100</span><span class="p">)</span> <span class="k">if</span> <span class="n">is_training</span> <span class="k">else</span> <span class="n">dataset</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">load_data</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span><span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span 
class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">batch_size</span><span class="o">=</span><span class="n">BATCH_SIZE</span><span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="p">(</span> <span class="n">dataset</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">augment</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span><span class="p">)</span> <span class="k">if</span> <span class="n">is_training</span> <span class="k">else</span> <span class="n">dataset</span> <span class="p">)</span> <span class="k">return</span> <span class="n">dataset</span> <span class="n">split_index</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">point_clouds</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">VAL_SPLIT</span><span class="p">))</span> <span class="n">train_point_clouds</span> <span class="o">=</span> <span class="n">point_clouds</span><span class="p">[:</span><span class="n">split_index</span><span class="p">]</span> <span class="n">train_label_cloud</span> <span class="o">=</span> <span class="n">point_cloud_labels</span><span class="p">[:</span><span class="n">split_index</span><span class="p">]</span> <span class="n">total_training_examples</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_point_clouds</span><span class="p">)</span> <span class="n">val_point_clouds</span> <span class="o">=</span> <span class="n">point_clouds</span><span class="p">[</span><span class="n">split_index</span><span class="p">:]</span> <span 
class="n">val_label_cloud</span> <span class="o">=</span> <span class="n">point_cloud_labels</span><span class="p">[</span><span class="n">split_index</span><span class="p">:]</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Num train point clouds:"</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_point_clouds</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Num train point cloud labels:"</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">train_label_cloud</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Num val point clouds:"</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">val_point_clouds</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Num val point cloud labels:"</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">val_label_cloud</span><span class="p">))</span> <span class="n">train_dataset</span> <span class="o">=</span> <span class="n">generate_dataset</span><span class="p">(</span><span class="n">train_point_clouds</span><span class="p">,</span> <span class="n">train_label_cloud</span><span class="p">)</span> <span class="n">val_dataset</span> <span class="o">=</span> <span class="n">generate_dataset</span><span class="p">(</span><span class="n">val_point_clouds</span><span class="p">,</span> <span class="n">val_label_cloud</span><span class="p">,</span> <span class="n">is_training</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Train Dataset:"</span><span class="p">,</span> <span class="n">train_dataset</span><span class="p">)</span> <span class="nb">print</span><span 
class="p">(</span><span class="s2">"Validation Dataset:"</span><span class="p">,</span> <span class="n">val_dataset</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Num train point clouds: 2955 Num train point cloud labels: 2955 Num val point clouds: 739 Num val point cloud labels: 739 Train Dataset: <_ParallelMapDataset element_spec=(TensorSpec(shape=(None, 1024, 3), dtype=tf.float64, name=None), TensorSpec(shape=(None, 1024, 5), dtype=tf.float64, name=None))> Validation Dataset: <_BatchDataset element_spec=(TensorSpec(shape=(None, 1024, 3), dtype=tf.float64, name=None), TensorSpec(shape=(None, 1024, 5), dtype=tf.float64, name=None))> </code></pre></div> </div> <hr /> <h2 id="pointnet-model">PointNet model</h2> <p>The figure below depicts the internals of the PointNet model family:</p> <p><img alt="" src="https://i.imgur.com/qFLNw5L.png" /></p> <p>Given that PointNet is meant to consume an <strong><em>unordered set</em></strong> of coordinates as its input data, its architecture needs to match the following characteristic properties of point cloud data:</p> <h3 id="permutation-invariance">Permutation invariance</h3> <p>Given the unstructured nature of point cloud data, a scan made up of <code>n</code> points has <code>n!</code> permutations. The subsequent data processing must be invariant to the different representations. In order to make PointNet invariant to input permutations, we use a symmetric function (such as max-pooling) once the <code>n</code> input points are mapped to higher-dimensional space. The result is a <strong>global feature vector</strong> that aims to capture an aggregate signature of the <code>n</code> input points. 
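</p> <p>The permutation-invariance argument can be checked with a small NumPy sketch. The random array below is only a stand-in for per-point features produced by a shared MLP, and its sizes (8 points, 16 features) are arbitrary assumptions, not values from this example.</p>
<div class="codehilite"><pre><span></span><code>import numpy as np

rng = np.random.default_rng(0)

# Stand-in for per-point features after the shared MLP: 8 points, 16 features.
point_features = rng.normal(size=(8, 16))

# The same set of points, presented in a different order.
shuffled_features = point_features[rng.permutation(8)]

# Max-pooling over the point axis yields one global feature vector per cloud.
global_feature = point_features.max(axis=0)
global_feature_shuffled = shuffled_features.max(axis=0)

# The pooled vector does not depend on the order of the points.
print(np.array_equal(global_feature, global_feature_shuffled))  # prints True
</code></pre></div>
<p>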
The global feature vector is used alongside local point features for segmentation.</p> <p><img alt="" src="https://i.imgur.com/0mrvvjb.png" /></p> <h3 id="transformation-invariance">Transformation invariance</h3> <p>Segmentation outputs should be unchanged if the object undergoes certain transformations, such as translation or scaling. For a given input point cloud, we apply an appropriate rigid or affine transformation, predicted by a small sub-network (the T-Net), to achieve pose normalization. Because each of the <code>n</code> input points is represented as a vector and is mapped to the embedding space independently, applying a geometric transformation simply amounts to multiplying each point by a transformation matrix. This is motivated by the concept of <a href="https://arxiv.org/abs/1506.02025">Spatial Transformer Networks</a>.</p> <p>The operations comprising the T-Net are motivated by the higher-level architecture of PointNet. MLPs (or fully-connected layers) are used to map the input points independently and identically to a higher-dimensional space; max-pooling is used to encode a global feature vector whose dimensionality is then reduced with fully-connected layers. The input-dependent features at the final fully-connected layer are then combined with globally trainable weights and biases, resulting in a 3-by-3 transformation matrix.</p> <p><img alt="" src="https://i.imgur.com/aEj3GYi.png" /></p> <h3 id="point-interactions">Point interactions</h3> <p>The interaction between neighboring points often carries useful information (i.e., a single point should not be treated in isolation). Whereas classification need only make use of global features, segmentation must be able to leverage local point features along with global point features.</p> <p><strong>Note</strong>: The figures presented in this section have been taken from the <a href="https://arxiv.org/abs/1612.00593">original paper</a>.</p> <p>Now that we know the pieces that compose the PointNet model, we can implement the model. 
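</p> <p>Before moving to the Keras layers, two of the ideas above can be sketched in plain NumPy: applying one 3-by-3 matrix to every point, and pairing each point's local features with the pooled global feature. This is an illustrative sketch only. The fixed rotation matrix stands in for a matrix that a T-Net would predict, and the feature width of 64 and the random inputs are assumptions, not values from this example.</p>
<div class="codehilite"><pre><span></span><code>import numpy as np

rng = np.random.default_rng(0)
points = rng.normal(size=(1024, 3))  # stand-in point cloud

# Transformation: a single matrix multiply applies the same 3-by-3 matrix
# to every point. A rotation about the z-axis is used here as a placeholder.
theta = np.pi / 6
transform = np.array(
    [
        [np.cos(theta), -np.sin(theta), 0.0],
        [np.sin(theta), np.cos(theta), 0.0],
        [0.0, 0.0, 1.0],
    ]
)
aligned_points = points @ transform

# Point interactions: segmentation sees local and global features together.
local_features = rng.normal(size=(1024, 64))  # placeholder per-point features
global_feature = local_features.max(axis=0)   # one vector for the whole cloud
tiled_global = np.tile(global_feature, (len(points), 1))
segmentation_features = np.concatenate([local_features, tiled_global], axis=-1)

print(aligned_points.shape, segmentation_features.shape)
</code></pre></div>
<p>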
We start by implementing the basic blocks, i.e., the convolutional block and the multi-layer perceptron block.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">conv_block</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">filters</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv1D</span><span class="p">(</span><span class="n">filters</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"valid"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">_conv"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">_batch_norm"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">layers</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="s2">"relu"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">_relu"</span><span 
class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="k">def</span> <span class="nf">mlp_block</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">filters</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">filters</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">_dense"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">_batch_norm"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">layers</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="s2">"relu"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">_relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> </code></pre></div> <p>We implement a regularizer (taken from <a href="https://keras.io/examples/vision/pointnet/#build-a-model">this example</a>) to enforce orthogonality in the feature space. 
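</p>
<p>As a rough, framework-free illustration of what such an orthogonality penalty measures, the NumPy sketch below evaluates <code>l2reg * sum((A @ A.T - I) ** 2)</code> for two small matrices. The helper <code>orthogonality_penalty</code> and the example matrices are hypothetical and are not part of the model code in this example.</p>

```python
import numpy as np


def orthogonality_penalty(matrix, l2reg=0.001):
    """Return l2reg * sum((A @ A.T - I) ** 2) for a square matrix A."""
    identity = np.eye(matrix.shape[0])
    gram = matrix @ matrix.T
    return l2reg * np.sum(np.square(gram - identity))


theta = np.pi / 6
# A pure rotation is orthogonal, so A @ A.T equals the identity.
rotation = np.array(
    [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]
)
# The same rotation scaled by 2 doubles vector lengths and is not orthogonal.
stretched = 2.0 * rotation

penalty_rotation = orthogonality_penalty(rotation)
penalty_stretched = orthogonality_penalty(stretched)
```

<p>The rotation incurs (numerically) no penalty, while the scaled matrix is penalized because it changes the length of every vector it transforms.</p>
<p>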
Keeping the learned transformation close to orthogonal ensures that the magnitudes of the transformed features do not vary too much.</p>
<div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">OrthogonalRegularizer</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">regularizers</span><span class="o">.</span><span class="n">Regularizer</span><span class="p">):</span>
<span class="w">    </span><span class="sd">"""Reference: https://keras.io/examples/vision/pointnet/#build-a-model"""</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_features</span><span class="p">,</span> <span class="n">l2reg</span><span class="o">=</span><span class="mf">0.001</span><span class="p">):</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">num_features</span> <span class="o">=</span> <span class="n">num_features</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">l2reg</span> <span class="o">=</span> <span class="n">l2reg</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">identity</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">num_features</span><span class="p">)</span>

    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span>
        <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_features</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_features</span><span class="p">))</span>
        <span class="c1"># Batched X @ X^T: one Gram matrix per sample in the batch.</span>
        <span class="n">xxt</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">)))</span>
        <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">l2reg</span> <span class="o">*</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">xxt</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">identity</span><span class="p">))</span>

    <span class="k">def</span> <span class="nf">get_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="c1"># Keys match the `__init__` argument names so `from_config` can rebuild the object.</span>
        <span class="k">return</span> <span class="p">{</span><span class="s2">"num_features"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_features</span><span class="p">,</span> <span class="s2">"l2reg"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">l2reg</span><span class="p">}</span>
</code></pre></div>
<p>The next piece is the transformation network (T-Net), which we explained earlier.</p>
<div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">transformation_net</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">num_features</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
<span class="w">    </span><span class="sd">"""</span>
<span class="sd">    Reference: https://keras.io/examples/vision/pointnet/#build-a-model.</span>
<span class="sd">    The `filters` values come from the original paper:</span>
<span class="sd">    https://arxiv.org/abs/1612.00593.</span>
<span class="sd">    """</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">conv_block</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">filters</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">_1"</span><span class="p">)</span>
    <span class="n">x</span> <span
class="o">=</span> <span class="n">conv_block</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">filters</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">_2"</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">conv_block</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">filters</span><span class="o">=</span><span class="mi">1024</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">_3"</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">GlobalMaxPooling1D</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">mlp_block</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">filters</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">_1_1"</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">mlp_block</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">filters</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">_2_1"</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span>
        <span class="n">num_features</span> <span class="o">*</span> <span class="n">num_features</span><span class="p">,</span>
        <span class="n">kernel_initializer</span><span class="o">=</span><span class="s2">"zeros"</span><span class="p">,</span>
        <span class="n">bias_initializer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">Constant</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">num_features</span><span class="p">)</span><span class="o">.</span><span class="n">flatten</span><span class="p">()),</span>
        <span class="n">activity_regularizer</span><span class="o">=</span><span class="n">OrthogonalRegularizer</span><span class="p">(</span><span class="n">num_features</span><span class="p">),</span>
        <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">_final"</span><span class="p">,</span>
    <span class="p">)(</span><span class="n">x</span><span class="p">)</span>


<span class="k">def</span> <span class="nf">transformation_block</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">num_features</span><span class="p">,</span> <span class="n">name</span><span class="p">):</span>
    <span class="n">transformed_features</span> <span class="o">=</span> <span class="n">transformation_net</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">num_features</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">)</span>
    <span class="n">transformed_features</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Reshape</span><span class="p">((</span><span class="n">num_features</span><span class="p">,</span> <span class="n">num_features</span><span class="p">))(</span>
        <span class="n">transformed_features</span>
    <span class="p">)</span>
    <span class="k">return</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dot</span><span class="p">(</span><span class="n">axes</span><span class="o">=</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">"</span><span class="si">{</span><span class="n">name</span><span class="si">}</span><span class="s2">_mm"</span><span class="p">)([</span><span class="n">inputs</span><span class="p">,</span> <span class="n">transformed_features</span><span class="p">])</span>
</code></pre></div>
<p>Finally, we piece the above blocks together and implement the segmentation model.</p>
<div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">get_shape_segmentation_model</span><span class="p">(</span><span class="n">num_points</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">):</span>
    <span class="n">input_points</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>

    <span class="c1"># PointNet Classification Network.</span>
    <span class="n">transformed_inputs</span> <span class="o">=</span> <span class="n">transformation_block</span><span class="p">(</span>
        <span class="n">input_points</span><span class="p">,</span> <span class="n">num_features</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"input_transformation_block"</span>
    <span class="p">)</span>
    <span class="n">features_64</span> <span class="o">=</span> <span class="n">conv_block</span><span class="p">(</span><span class="n">transformed_inputs</span><span class="p">,</span> <span class="n">filters</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"features_64"</span><span class="p">)</span>
    <span class="n">features_128_1</span> <span class="o">=</span> <span class="n">conv_block</span><span class="p">(</span><span class="n">features_64</span><span class="p">,</span> <span class="n">filters</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"features_128_1"</span><span class="p">)</span>
    <span class="n">features_128_2</span> <span class="o">=</span> <span class="n">conv_block</span><span class="p">(</span><span class="n">features_128_1</span><span class="p">,</span> <span class="n">filters</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"features_128_2"</span><span class="p">)</span>
    <span class="n">transformed_features</span> <span class="o">=</span> <span class="n">transformation_block</span><span class="p">(</span>
        <span class="n">features_128_2</span><span class="p">,</span> <span class="n">num_features</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"transformed_features"</span>
    <span class="p">)</span>
    <span class="n">features_512</span> <span class="o">=</span> <span class="n">conv_block</span><span class="p">(</span><span class="n">transformed_features</span><span class="p">,</span> <span class="n">filters</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"features_512"</span><span class="p">)</span>
    <span class="n">features_2048</span> <span class="o">=</span> <span class="n">conv_block</span><span class="p">(</span><span class="n">features_512</span><span class="p">,</span> <span class="n">filters</span><span class="o">=</span><span class="mi">2048</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"pre_maxpool_block"</span><span class="p">)</span>
    <span class="n">global_features</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MaxPool1D</span><span class="p">(</span><span class="n">pool_size</span><span class="o">=</span><span class="n">num_points</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"global_features"</span><span class="p">)(</span>
        <span class="n">features_2048</span>
    <span class="p">)</span>
    <span class="n">global_features</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">global_features</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="n">num_points</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span>

    <span class="c1"># Segmentation head.</span>
    <span class="n">segmentation_input</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Concatenate</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"segmentation_input"</span><span class="p">)(</span>
        <span class="p">[</span>
            <span class="n">features_64</span><span class="p">,</span>
            <span class="n">features_128_1</span><span class="p">,</span>
            <span class="n">features_128_2</span><span class="p">,</span>
            <span class="n">transformed_features</span><span class="p">,</span>
            <span class="n">features_512</span><span class="p">,</span>
            <span class="n">global_features</span><span class="p">,</span>
        <span class="p">]</span>
    <span class="p">)</span>
    <span class="n">segmentation_features</span> <span class="o">=</span> <span class="n">conv_block</span><span class="p">(</span>
        <span class="n">segmentation_input</span><span class="p">,</span> <span class="n">filters</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"segmentation_features"</span>
    <span class="p">)</span>
    <span class="n">outputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv1D</span><span class="p">(</span>
        <span class="n">num_classes</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"softmax"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"segmentation_head"</span>
    <span class="p">)(</span><span class="n">segmentation_features</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">input_points</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span>
</code></pre></div>
<hr />
<h2 id="instantiate-the-model">Instantiate
the model</h2>
<div class="codehilite"><pre><span></span><code><span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">))</span>

<span class="n">num_points</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
<span class="n">num_classes</span> <span class="o">=</span> <span class="n">y</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>

<span class="n">segmentation_model</span> <span class="o">=</span> <span class="n">get_shape_segmentation_model</span><span class="p">(</span><span class="n">num_points</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span>
<span class="n">segmentation_model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span>
</code></pre></div>
<pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "functional_1"</span> </pre>
<pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃<span style="font-weight: bold"> Connected to </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff;
text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ - │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">InputLayer</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ input_transformati… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">256</span> │ input_layer[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>] │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv1D</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ input_transformati… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">256</span> │ input_transformatio… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">BatchNormalizatio…</span> │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ input_transformati… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ input_transformatio… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Activation</span>) │ │ │ │ 
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ input_transformati… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">8,320</span> │ input_transformatio… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv1D</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ input_transformati… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">512</span> │ input_transformatio… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">BatchNormalizatio…</span> │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ input_transformati… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ input_transformatio… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Activation</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ input_transformati… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, │ <span style="color: #00af00; text-decoration-color: #00af00">132,096</span> │ input_transformatio… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv1D</span>) │ <span style="color: 
#00af00; text-decoration-color: #00af00">1024</span>) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ input_transformati… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, │ <span style="color: #00af00; text-decoration-color: #00af00">4,096</span> │ input_transformatio… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">BatchNormalizatio…</span> │ <span style="color: #00af00; text-decoration-color: #00af00">1024</span>) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ input_transformati… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ input_transformatio… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Activation</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">1024</span>) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ global_max_pooling… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">1024</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ input_transformatio… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">GlobalMaxPooling1…</span> │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ input_transformati… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">524,800</span> │ global_max_pooling1… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ │ │ │ 
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ input_transformati… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">2,048</span> │ input_transformatio… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">BatchNormalizatio…</span> │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ input_transformati… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ input_transformatio… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Activation</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ input_transformati… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">256</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">131,328</span> │ input_transformatio… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ input_transformati… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">256</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">1,024</span> │ input_transformatio… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">BatchNormalizatio…</span> │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ input_transformati… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: 
#00af00">256</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ input_transformatio… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Activation</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ input_transformati… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">9</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">2,313</span> │ input_transformatio… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ reshape (<span style="color: #0087ff; text-decoration-color: #0087ff">Reshape</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ input_transformatio… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ input_transformati… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ input_layer[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>], │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Dot</span>) │ │ │ reshape[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>] │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ 
features_64_conv │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">256</span> │ input_transformatio… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv1D</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ features_64_batch_… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">256</span> │ features_64_conv[<span style="color: #00af00; text-decoration-color: #00af00">0</span>]… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">BatchNormalizatio…</span> │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ features_64_relu │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ features_64_batch_n… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Activation</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ features_128_1_conv │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">8,320</span> │ features_64_relu[<span style="color: #00af00; text-decoration-color: 
#00af00">0</span>]… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv1D</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ features_128_1_bat… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">512</span> │ features_128_1_conv… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">BatchNormalizatio…</span> │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ features_128_1_relu │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ features_128_1_batc… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Activation</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ features_128_2_conv │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">16,512</span> │ features_128_1_relu… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv1D</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ features_128_2_bat… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: 
#00af00; text-decoration-color: #00af00">512</span> │ features_128_2_conv… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">BatchNormalizatio…</span> │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ features_128_2_relu │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ features_128_2_batc… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Activation</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ transformed_featur… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">8,256</span> │ features_128_2_relu… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv1D</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ transformed_featur… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">256</span> │ transformed_feature… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">BatchNormalizatio…</span> │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ transformed_featur… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; 
text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ transformed_feature… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Activation</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ transformed_featur… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">8,320</span> │ transformed_feature… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv1D</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ transformed_featur… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">512</span> │ transformed_feature… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">BatchNormalizatio…</span> │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ transformed_featur… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ transformed_feature… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Activation</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ transformed_featur… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; 
text-decoration-color: #00d7ff">None</span>, │ <span style="color: #00af00; text-decoration-color: #00af00">132,096</span> │ transformed_feature… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv1D</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">1024</span>) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ transformed_featur… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, │ <span style="color: #00af00; text-decoration-color: #00af00">4,096</span> │ transformed_feature… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">BatchNormalizatio…</span> │ <span style="color: #00af00; text-decoration-color: #00af00">1024</span>) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ transformed_featur… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ transformed_feature… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Activation</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">1024</span>) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ global_max_pooling… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">1024</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ transformed_feature… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">GlobalMaxPooling1…</span> │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ transformed_featur… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; 
text-decoration-color: #00af00">512</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">524,800</span> │ global_max_pooling1… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ transformed_featur… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">2,048</span> │ transformed_feature… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">BatchNormalizatio…</span> │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ transformed_featur… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ transformed_feature… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Activation</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ transformed_featur… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">256</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">131,328</span> │ transformed_feature… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ transformed_featur… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">256</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">1,024</span> │ transformed_feature… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">BatchNormalizatio…</span> │ │ 
│ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ transformed_featur… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">256</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ transformed_feature… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Activation</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ transformed_featur… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">16384</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">4,210,…</span> │ transformed_feature… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ reshape_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">Reshape</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ transformed_feature… │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ transformed_featur… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ features_128_2_relu… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Dot</span>) │ │ │ reshape_1[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; 
text-decoration-color: #00af00">0</span>] │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ features_512_conv │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">66,048</span> │ transformed_feature… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv1D</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ features_512_batch… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">2,048</span> │ features_512_conv[<span style="color: #00af00; text-decoration-color: #00af00">0</span>… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">BatchNormalizatio…</span> │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ features_512_relu │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ features_512_batch_… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Activation</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ pre_maxpool_block_… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, │ <span style="color: #00af00; text-decoration-color: #00af00">1,050,…</span> │ 
features_512_relu[<span style="color: #00af00; text-decoration-color: #00af00">0</span>… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv1D</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">2048</span>) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ pre_maxpool_block_… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, │ <span style="color: #00af00; text-decoration-color: #00af00">8,192</span> │ pre_maxpool_block_c… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">BatchNormalizatio…</span> │ <span style="color: #00af00; text-decoration-color: #00af00">2048</span>) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ pre_maxpool_block_… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ pre_maxpool_block_b… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Activation</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">2048</span>) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ global_features │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ pre_maxpool_block_r… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">MaxPooling1D</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">2048</span>) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ tile (<span style="color: #0087ff; text-decoration-color: #0087ff">Tile</span>) │ (<span style="color: #00d7ff; 
text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ global_features[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">…</span> │ │ │ <span style="color: #00af00; text-decoration-color: #00af00">2048</span>) │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ segmentation_input │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ features_64_relu[<span style="color: #00af00; text-decoration-color: #00af00">0</span>]… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Concatenate</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">3008</span>) │ │ features_128_1_relu… │ │ │ │ │ features_128_2_relu… │ │ │ │ │ transformed_feature… │ │ │ │ │ features_512_relu[<span style="color: #00af00; text-decoration-color: #00af00">0</span>… │ │ │ │ │ tile[<span style="color: #00af00; text-decoration-color: #00af00">0</span>][<span style="color: #00af00; text-decoration-color: #00af00">0</span>] │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ segmentation_featu… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">385,152</span> │ segmentation_input[<span style="color: #00af00; text-decoration-color: #00af00">…</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv1D</span>) │ │ │ │ 
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ segmentation_featu… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">512</span> │ segmentation_featur… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">BatchNormalizatio…</span> │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ segmentation_featu… │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ segmentation_featur… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Activation</span>) │ │ │ │ ├─────────────────────┼───────────────────┼─────────┼──────────────────────┤ │ segmentation_head │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">5</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">645</span> │ segmentation_featur… │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv1D</span>) │ │ │ │ └─────────────────────┴───────────────────┴─────────┴──────────────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">7,370,062</span> (28.11 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans 
Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">7,356,110</span> (28.06 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">13,952</span> (54.50 KB) </pre> <hr /> <h2 id="training">Training</h2> <p>For training, the authors recommend a learning rate schedule that halves the learning rate every 20 epochs. In this example, we halve it every 5 epochs instead.</p> <div class="codehilite"><pre><span></span><code><span class="n">steps_per_epoch</span> <span class="o">=</span> <span class="n">total_training_examples</span> <span class="o">//</span> <span class="n">BATCH_SIZE</span>
<span class="n">total_training_steps</span> <span class="o">=</span> <span class="n">steps_per_epoch</span> <span class="o">*</span> <span class="n">EPOCHS</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Steps per epoch: </span><span class="si">{</span><span class="n">steps_per_epoch</span><span class="si">}</span><span class="s2">."</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Total training steps: </span><span class="si">{</span><span class="n">total_training_steps</span><span class="si">}</span><span class="s2">."</span><span class="p">)</span>

<span class="n">lr_schedule</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">schedules</span><span class="o">.</span><span class="n">ExponentialDecay</span><span class="p">(</span>
    <span class="n">initial_learning_rate</span><span class="o">=</span><span class="mf">0.003</span><span class="p">,</span>
    <span class="n">decay_steps</span><span class="o">=</span><span class="n">steps_per_epoch</span> <span class="o">*</span> <span class="mi">5</span><span class="p">,</span>
    <span class="n">decay_rate</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span>
    <span class="n">staircase</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="p">)</span>

<span class="n">steps</span> <span class="o">=</span> <span class="nb">range</span><span class="p">(</span><span class="n">total_training_steps</span><span class="p">)</span>
<span class="n">lrs</span> <span class="o">=</span> <span class="p">[</span><span class="n">lr_schedule</span><span class="p">(</span><span class="n">step</span><span class="p">)</span> <span class="k">for</span> <span class="n">step</span> <span class="ow">in</span> <span class="n">steps</span><span class="p">]</span>

<span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">lrs</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s2">"Steps"</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s2">"Learning Rate"</span><span class="p">)</span>
<span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Steps per epoch: 92.
Total training steps: 5520.
</code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/pointnet_segmentation/pointnet_segmentation_34_1.png" /></p> <p>Finally, we implement a utility for running our experiments and launch model training.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">run_experiment</span><span class="p">(</span><span class="n">epochs</span><span class="p">):</span>
    <span class="n">segmentation_model</span> <span class="o">=</span> <span class="n">get_shape_segmentation_model</span><span class="p">(</span><span class="n">num_points</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span>
    <span class="n">segmentation_model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span>
        <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="n">lr_schedule</span><span class="p">),</span>
        <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">CategoricalCrossentropy</span><span class="p">(),</span>
        <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"accuracy"</span><span class="p">],</span>
    <span class="p">)</span>
    <span class="n">checkpoint_filepath</span> <span class="o">=</span> <span class="s2">"checkpoint.weights.h5"</span>
    <span class="n">checkpoint_callback</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">ModelCheckpoint</span><span class="p">(</span>
        <span class="n">checkpoint_filepath</span><span class="p">,</span>
        <span class="n">monitor</span><span class="o">=</span><span class="s2">"val_loss"</span><span class="p">,</span>
        <span class="n">save_best_only</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
        <span class="n">save_weights_only</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    <span class="p">)</span>
    <span class="n">history</span> <span class="o">=</span> <span class="n">segmentation_model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span>
        <span class="n">train_dataset</span><span class="p">,</span>
        <span class="n">validation_data</span><span class="o">=</span><span class="n">val_dataset</span><span class="p">,</span>
        <span class="n">epochs</span><span class="o">=</span><span class="n">epochs</span><span class="p">,</span>
        <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">checkpoint_callback</span><span class="p">],</span>
    <span class="p">)</span>
    <span class="n">segmentation_model</span><span class="o">.</span><span class="n">load_weights</span><span class="p">(</span><span class="n">checkpoint_filepath</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">segmentation_model</span><span class="p">,</span> <span class="n">history</span>

<span class="n">segmentation_model</span><span class="p">,</span> <span class="n">history</span> <span class="o">=</span> <span class="n">run_experiment</span><span class="p">(</span><span class="n">epochs</span><span class="o">=</span><span class="n">EPOCHS</span><span class="p">)</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/60 2/93 ━━━━━━━━━━━━━━━━━━━━ 7s 86ms/step - accuracy: 0.1427 - loss: 48748.8203 WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1699916678.434176 90326 device_compiler.h:187] Compiled cluster using XLA!
This line is logged at most once for the lifetime of the process. 93/93 ━━━━━━━━━━━━━━━━━━━━ 53s 259ms/step - accuracy: 0.3739 - loss: 27980.7305 - val_accuracy: 0.4340 - val_loss: 10361231.0000 Epoch 2/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 48s 82ms/step - accuracy: 0.6355 - loss: 339.9151 - val_accuracy: 0.3820 - val_loss: 19069320.0000 Epoch 3/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.6695 - loss: 281.5728 - val_accuracy: 0.2859 - val_loss: 15993839.0000 Epoch 4/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.6812 - loss: 253.0939 - val_accuracy: 0.2287 - val_loss: 9633191.0000 Epoch 5/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.6873 - loss: 231.1317 - val_accuracy: 0.3030 - val_loss: 6001454.0000 Epoch 6/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.6860 - loss: 216.6793 - val_accuracy: 0.0620 - val_loss: 1945100.8750 Epoch 7/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.6947 - loss: 210.2683 - val_accuracy: 0.4539 - val_loss: 7908162.5000 Epoch 8/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7014 - loss: 203.2560 - val_accuracy: 0.4035 - val_loss: 17741164.0000 Epoch 9/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7006 - loss: 197.3710 - val_accuracy: 0.1900 - val_loss: 34120616.0000 Epoch 10/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7047 - loss: 192.0777 - val_accuracy: 0.3391 - val_loss: 33157422.0000 Epoch 11/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7102 - loss: 188.4875 - val_accuracy: 0.3394 - val_loss: 4630613.5000 Epoch 12/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7186 - loss: 184.9940 - val_accuracy: 0.1662 - val_loss: 487790.1250 Epoch 13/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7175 - loss: 182.7206 - val_accuracy: 0.1602 - val_loss: 70590.3203 Epoch 14/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7159 - loss: 180.5028 - val_accuracy: 0.1631 - val_loss: 16990.2324 Epoch 15/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 
88ms/step - accuracy: 0.7201 - loss: 180.1674 - val_accuracy: 0.2318 - val_loss: 4992.7783 Epoch 16/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7222 - loss: 176.5523 - val_accuracy: 0.6246 - val_loss: 647.5634 Epoch 17/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7291 - loss: 175.6139 - val_accuracy: 0.6551 - val_loss: 324.0956 Epoch 18/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7285 - loss: 175.0228 - val_accuracy: 0.6430 - val_loss: 257.9340 Epoch 19/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7300 - loss: 172.7668 - val_accuracy: 0.6399 - val_loss: 253.2745 Epoch 20/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7316 - loss: 172.9001 - val_accuracy: 0.6084 - val_loss: 232.9293 Epoch 21/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7364 - loss: 170.8767 - val_accuracy: 0.6451 - val_loss: 191.7183 Epoch 22/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7395 - loss: 171.4525 - val_accuracy: 0.6825 - val_loss: 180.2473 Epoch 23/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7392 - loss: 170.1975 - val_accuracy: 0.6095 - val_loss: 180.3243 Epoch 24/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7362 - loss: 169.2144 - val_accuracy: 0.6017 - val_loss: 178.3013 Epoch 25/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7409 - loss: 169.2571 - val_accuracy: 0.6582 - val_loss: 178.3481 Epoch 26/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7415 - loss: 167.7480 - val_accuracy: 0.6808 - val_loss: 177.8774 Epoch 27/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7440 - loss: 167.7844 - val_accuracy: 0.7131 - val_loss: 176.5841 Epoch 28/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7423 - loss: 167.5307 - val_accuracy: 0.6891 - val_loss: 176.1687 Epoch 29/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7409 - loss: 166.4581 - val_accuracy: 0.7136 - val_loss: 174.9417 Epoch 30/60 93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - 
accuracy: 0.7419 - loss: 165.9243 - val_accuracy: 0.7407 - val_loss: 173.0663
Epoch 31/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7471 - loss: 166.9746 - val_accuracy: 0.7454 - val_loss: 172.9663
Epoch 32/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7472 - loss: 165.9707 - val_accuracy: 0.7480 - val_loss: 173.9868
Epoch 33/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7443 - loss: 165.9368 - val_accuracy: 0.7076 - val_loss: 174.4526
Epoch 34/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7496 - loss: 165.5322 - val_accuracy: 0.7441 - val_loss: 174.6099
Epoch 35/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7453 - loss: 164.2007 - val_accuracy: 0.7469 - val_loss: 174.2793
Epoch 36/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7503 - loss: 165.3418 - val_accuracy: 0.7469 - val_loss: 174.0812
Epoch 37/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7491 - loss: 164.4796 - val_accuracy: 0.7524 - val_loss: 173.9656
Epoch 38/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 10s 82ms/step - accuracy: 0.7489 - loss: 164.4573 - val_accuracy: 0.7516 - val_loss: 175.3401
Epoch 39/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7437 - loss: 163.4484 - val_accuracy: 0.7532 - val_loss: 173.8172
Epoch 40/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7507 - loss: 163.6720 - val_accuracy: 0.7537 - val_loss: 173.9127
Epoch 41/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7506 - loss: 164.0555 - val_accuracy: 0.7556 - val_loss: 173.0979
Epoch 42/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7517 - loss: 164.1554 - val_accuracy: 0.7562 - val_loss: 172.8895
Epoch 43/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 10s 82ms/step - accuracy: 0.7527 - loss: 164.6351 - val_accuracy: 0.7567 - val_loss: 173.0476
Epoch 44/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7505 - loss: 164.1568 - val_accuracy: 0.7571 - val_loss: 172.2751
Epoch 45/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7500 - loss: 163.8129 - val_accuracy: 0.7579 - val_loss: 171.8897
Epoch 46/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7534 - loss: 163.6473 - val_accuracy: 0.7577 - val_loss: 172.5457
Epoch 47/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7510 - loss: 163.7318 - val_accuracy: 0.7580 - val_loss: 172.2256
Epoch 48/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7517 - loss: 163.3274 - val_accuracy: 0.7575 - val_loss: 172.3276
Epoch 49/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7511 - loss: 163.5069 - val_accuracy: 0.7581 - val_loss: 171.2155
Epoch 50/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 89ms/step - accuracy: 0.7507 - loss: 163.7366 - val_accuracy: 0.7578 - val_loss: 171.1100
Epoch 51/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7519 - loss: 163.1190 - val_accuracy: 0.7580 - val_loss: 171.7971
Epoch 52/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 81ms/step - accuracy: 0.7510 - loss: 162.7351 - val_accuracy: 0.7579 - val_loss: 171.9780
Epoch 53/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7510 - loss: 162.9639 - val_accuracy: 0.7577 - val_loss: 171.6770
Epoch 54/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 88ms/step - accuracy: 0.7530 - loss: 162.7419 - val_accuracy: 0.7578 - val_loss: 170.5556
Epoch 55/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7515 - loss: 163.2893 - val_accuracy: 0.7582 - val_loss: 171.9172
Epoch 56/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7505 - loss: 164.2843 - val_accuracy: 0.7584 - val_loss: 171.9182
Epoch 57/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7498 - loss: 162.6679 - val_accuracy: 0.7587 - val_loss: 173.7610
Epoch 58/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7523 - loss: 163.3332 - val_accuracy: 0.7585 - val_loss: 172.5207
Epoch 59/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7529 - loss: 162.4575 - val_accuracy: 0.7586 - val_loss: 171.6861
Epoch 60/60
93/93 ━━━━━━━━━━━━━━━━━━━━ 8s 82ms/step - accuracy: 0.7498 - loss: 162.9523 -
val_accuracy: 0.7586 - val_loss: 172.3012
</code></pre></div>
</div>
<hr />
<h2 id="visualize-the-training-landscape">Visualize the training landscape</h2>
<div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">plot_result</span><span class="p">(</span><span class="n">item</span><span class="p">):</span>
    <span class="c1"># Plot the training and validation curves for one metric stored in history.history.</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="n">item</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="n">item</span><span class="p">)</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">"val_"</span> <span class="o">+</span> <span class="n">item</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s2">"val_"</span> <span class="o">+</span> <span class="n">item</span><span class="p">)</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s2">"Epochs"</span><span class="p">)</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="n">item</span><span class="p">)</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">"Train and Validation </span><span class="si">{}</span><span class="s2"> Over Epochs"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">item</span><span class="p">),</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">grid</span><span class="p">()</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>


<span class="n">plot_result</span><span class="p">(</span><span class="s2">"loss"</span><span class="p">)</span>
<span class="n">plot_result</span><span class="p">(</span><span class="s2">"accuracy"</span><span class="p">)</span>
</code></pre></div>
<p><img alt="png" src="/img/examples/vision/pointnet_segmentation/pointnet_segmentation_38_0.png" /></p>
<p><img alt="png" src="/img/examples/vision/pointnet_segmentation/pointnet_segmentation_38_1.png" /></p>
<hr />
<h2 id="inference">Inference</h2>
<div class="codehilite"><pre><span></span><code><span class="c1"># Take one batch from the validation set and predict per-point class scores.</span>
<span class="n">validation_batch</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">val_dataset</span><span class="p">))</span>
<span class="n">val_predictions</span> <span class="o">=</span> <span class="n">segmentation_model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">validation_batch</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Validation prediction shape: </span><span class="si">{</span><span class="n">val_predictions</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>


<span class="k">def</span> <span class="nf">visualize_single_point_cloud</span><span class="p">(</span><span class="n">point_clouds</span><span class="p">,</span> <span class="n">label_clouds</span><span class="p">,</span> <span class="n">idx</span><span class="p">):</span>
    <span class="n">label_map</span> <span class="o">=</span> <span class="n">LABELS</span> <span class="o">+</span> <span class="p">[</span><span class="s2">"none"</span><span class="p">]</span>
    <span class="n">point_cloud</span> <span class="o">=</span> <span class="n">point_clouds</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
    <span class="n">label_cloud</span> <span class="o">=</span> <span class="n">label_clouds</span><span class="p">[</span><span class="n">idx</span><span class="p">]</span>
    <span class="c1"># Map each point's label vector to a class name through its argmax index.</span>
    <span class="n">visualize_data</span><span class="p">(</span><span class="n">point_cloud</span><span class="p">,</span> <span class="p">[</span><span class="n">label_map</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">label</span><span class="p">)]</span> <span class="k">for</span> <span class="n">label</span> <span class="ow">in</span> <span class="n">label_cloud</span><span class="p">])</span>


<span class="n">idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">validation_batch</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Index selected: </span><span class="si">{</span><span class="n">idx</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>

<span class="c1"># Plotting with ground-truth.</span>
<span class="n">visualize_single_point_cloud</span><span class="p">(</span><span class="n">validation_batch</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">validation_batch</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="n">idx</span><span class="p">)</span>

<span class="c1"># Plotting with predicted labels.</span>
<span class="n">visualize_single_point_cloud</span><span class="p">(</span><span class="n">validation_batch</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">val_predictions</span><span class="p">,</span> <span class="n">idx</span><span class="p">)</span>
</code></pre></div>
<div class="k-default-codeblock">
<div class="codehilite"><pre><span></span><code> 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step
Validation prediction shape: (32, 1024, 5)
Index selected: 26
</code></pre></div>
</div>
<p><img alt="png" src="/img/examples/vision/pointnet_segmentation/pointnet_segmentation_40_1.png" /></p>
<p><img alt="png" src="/img/examples/vision/pointnet_segmentation/pointnet_segmentation_40_2.png" /></p>
<hr />
<h2 id="final-notes">Final notes</h2>
<p>If you want to learn more about point cloud segmentation, you may find <a href="https://github.com/soumik12345/point-cloud-segmentation">this repository</a> useful.</p>
</div>
<div class='k-outline'>
<div class='k-outline-depth-1'> <a href='#point-cloud-segmentation-with-pointnet'>Point cloud segmentation with PointNet</a> </div>
<div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div>
<div class='k-outline-depth-3'> <a href='#references'>References</a> </div>
<div class='k-outline-depth-2'> ◆ <a href='#imports'>Imports</a> </div>
<div class='k-outline-depth-2'> ◆ <a href='#downloading-dataset'>Downloading Dataset</a> </div>
<div class='k-outline-depth-2'> ◆ <a href='#loading-the-dataset'>Loading the dataset</a> </div>
<div class='k-outline-depth-2'> ◆ <a href='#structuring-the-dataset'>Structuring the dataset</a> </div>
<div class='k-outline-depth-3'> <a href='#preprocessing'>Preprocessing</a> </div>
<div class='k-outline-depth-3'> <a href='#creating-tensorflow-datasets'>Creating TensorFlow datasets</a> </div>
<div class='k-outline-depth-2'> ◆ <a href='#pointnet-model'>PointNet model</a>
</div> <div class='k-outline-depth-3'> <a href='#permutation-invariance'>Permutation invariance</a> </div> <div class='k-outline-depth-3'> <a href='#transformation-invariance'>Transformation invariance</a> </div> <div class='k-outline-depth-3'> <a href='#point-interactions'>Point interactions</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#instantiate-the-model'>Instantiate the model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#training'>Training</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#visualize-the-training-landscape'>Visualize the training landscape</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#inference'>Inference</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#final-notes'>Final notes</a> </div> </div> </div> </div> </div> </body> </html>