<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/3D_image_classification/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: 3D image classification from CT scans"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: 3D image classification from CT scans"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>3D image classification from CT scans</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 
'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a 
class="nav-sublink2" href="/examples/vision/mobilevit/">A mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using 
Composable Fully-Convolutional Networks</a> <a class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2 active" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved 
Robustness</a> <a class="nav-sublink2" href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a 
class="nav-sublink2" href="/examples/vision/bit/">Image Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a 
class="nav-sublink2" href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = 
document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return false; } </script> </div> <div 
class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/vision/'>Computer Vision</a> / 3D image classification from CT scans </div> <div class='k-content'> <h1 id="3d-image-classification-from-ct-scans">3D image classification from CT scans</h1> <p><strong>Author:</strong> <a href="https://twitter.com/hasibzunair">Hasib Zunair</a><br> <strong>Date created:</strong> 2020/09/23<br> <strong>Last modified:</strong> 2024/01/11<br> <strong>Description:</strong> Train a 3D convolutional neural network to predict the presence of pneumonia.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/3D_image_classification.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/3D_image_classification.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>This example shows the steps needed to build a 3D convolutional neural network (CNN) to predict the presence of viral pneumonia in computed tomography (CT) scans. 2D CNNs are commonly used to process RGB images (3 channels). A 3D CNN is simply the 3D equivalent: it takes as input a 3D volume or a sequence of 2D frames (e.g. 
slices in a CT scan). 3D CNNs are powerful models for learning representations of volumetric data.</p> <hr /> <h2 id="references">References</h2> <ul> <li><a href="https://arxiv.org/abs/1808.01462">A survey on Deep Learning Advances on Different 3D Data Representations</a></li> <li><a href="https://www.ri.cmu.edu/pub_files/2015/9/voxnet_maturana_scherer_iros15.pdf">VoxNet: A 3D Convolutional Neural Network for Real-Time Object Recognition</a></li> <li><a href="https://arxiv.org/abs/1607.05695">FusionNet: 3D Object Classification Using Multiple Data Representations</a></li> <li><a href="https://arxiv.org/abs/2007.13224">Uniformizing Techniques to Process CT scans with 3D CNNs for Tuberculosis Prediction</a></li> </ul> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">zipfile</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span>  <span class="c1"># for data preprocessing</span>

<span class="kn">import</span> <span class="nn">keras</span>
<span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span>
</code></pre></div> <hr /> <h2 id="downloading-the-mosmeddata-chest-ct-scans-with-covid19-related-findings">Downloading the MosMedData: Chest CT Scans with COVID-19 Related Findings</h2> <p>In this example, we use a subset of the <a href="https://www.medrxiv.org/content/10.1101/2020.05.20.20100362v1">MosMedData: Chest CT Scans with COVID-19 Related Findings</a>. 
This dataset consists of lung CT scans with COVID-19 related findings, as well as without such findings.</p> <p>We will be using the associated radiological findings of the CT scans as labels to build a classifier to predict the presence of viral pneumonia. Hence, the task is a binary classification problem.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Download URL of the normal CT scans.</span>
<span class="n">url</span> <span class="o">=</span> <span class="s2">&quot;https://github.com/hasibzunair/3D-image-classification-tutorial/releases/download/v0.2/CT-0.zip&quot;</span>
<span class="n">filename</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">getcwd</span><span class="p">(),</span> <span class="s2">&quot;CT-0.zip&quot;</span><span class="p">)</span>
<span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">get_file</span><span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="n">url</span><span class="p">)</span>

<span class="c1"># Download URL of the abnormal CT scans.</span>
<span class="n">url</span> <span class="o">=</span> <span class="s2">&quot;https://github.com/hasibzunair/3D-image-classification-tutorial/releases/download/v0.2/CT-23.zip&quot;</span>
<span class="n">filename</span> <span class="o">=</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">getcwd</span><span class="p">(),</span> <span class="s2">&quot;CT-23.zip&quot;</span><span class="p">)</span>
<span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">get_file</span><span class="p">(</span><span class="n">filename</span><span class="p">,</span> <span class="n">url</span><span class="p">)</span>

<span class="c1"># Make a directory to store the data (no error if it already exists).</span>
<span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="s2">&quot;MosMedData&quot;</span><span class="p">,</span> <span class="n">exist_ok</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>

<span class="c1"># Unzip data into the data directory.</span>
<span class="k">with</span> <span class="n">zipfile</span><span class="o">.</span><span class="n">ZipFile</span><span class="p">(</span><span class="s2">&quot;CT-0.zip&quot;</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">z_fp</span><span class="p">:</span>
    <span class="n">z_fp</span><span class="o">.</span><span class="n">extractall</span><span class="p">(</span><span class="s2">&quot;./MosMedData/&quot;</span><span class="p">)</span>

<span class="k">with</span> <span class="n">zipfile</span><span class="o">.</span><span class="n">ZipFile</span><span class="p">(</span><span class="s2">&quot;CT-23.zip&quot;</span><span class="p">,</span> <span class="s2">&quot;r&quot;</span><span class="p">)</span> <span class="k">as</span> <span class="n">z_fp</span><span class="p">:</span>
    <span class="n">z_fp</span><span class="o">.</span><span class="n">extractall</span><span class="p">(</span><span class="s2">&quot;./MosMedData/&quot;</span><span class="p">)</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Downloading data from https://github.com/hasibzunair/3D-image-classification-tutorial/releases/download/v0.2/CT-0.zip
</code></pre></div> </div> <p>1045162547/1045162547 ━━━━━━━━━━━━━━━━━━━━ 4s 0us/step</p> <hr /> <h2 id="loading-data-and-preprocessing">Loading data and preprocessing</h2> <p>The files are provided in NIfTI format with the extension <code>.nii</code>. 
To read the scans, we use the <code>nibabel</code> package. You can install the package via <code>pip install nibabel</code>. CT scans store raw voxel intensity in Hounsfield units (HU). They range from -1024 to above 2000 in this dataset. Values above 400 HU correspond to bones of varying radiodensity, so 400 is used as the upper bound. A window of -1000 to 400 HU is commonly used to normalize CT scans.</p> <p>To process the data, we do the following:</p> <ul> <li>We clip the HU values to the range -1000 to 400 and scale them to be between 0 and 1.</li> <li>We rotate the volumes by 90 degrees, so the orientation is fixed.</li> <li>We resize width, height and depth.</li> </ul> <p>Here we define several helper functions to process the data. These functions will be used when building training and validation datasets.</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">nibabel</span> <span class="k">as</span> <span class="nn">nib</span>

<span class="kn">from</span> <span class="nn">scipy</span> <span class="kn">import</span> <span class="n">ndimage</span>


<span class="k">def</span> <span class="nf">read_nifti_file</span><span class="p">(</span><span class="n">filepath</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;Read and load volume&quot;&quot;&quot;</span>
    <span class="c1"># Read file</span>
    <span class="n">scan</span> <span class="o">=</span> <span class="n">nib</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">filepath</span><span class="p">)</span>
    <span class="c1"># Get raw data</span>
    <span class="n">scan</span> <span class="o">=</span> <span class="n">scan</span><span class="o">.</span><span class="n">get_fdata</span><span class="p">()</span>
    <span class="k">return</span> <span class="n">scan</span>


<span class="k">def</span> <span class="nf">normalize</span><span class="p">(</span><span class="n">volume</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;Normalize the volume&quot;&quot;&quot;</span>
    <span class="c1"># Clip to the [-1000, 400] HU window, then scale to [0, 1]</span>
    <span class="n">min_hu</span> <span class="o">=</span> <span class="o">-</span><span class="mi">1000</span>
    <span class="n">max_hu</span> <span class="o">=</span> <span class="mi">400</span>
    <span class="n">volume</span><span class="p">[</span><span class="n">volume</span> <span class="o">&lt;</span> <span class="n">min_hu</span><span class="p">]</span> <span class="o">=</span> <span class="n">min_hu</span>
    <span class="n">volume</span><span class="p">[</span><span class="n">volume</span> <span class="o">&gt;</span> <span class="n">max_hu</span><span class="p">]</span> <span class="o">=</span> <span class="n">max_hu</span>
    <span class="n">volume</span> <span class="o">=</span> <span class="p">(</span><span class="n">volume</span> <span class="o">-</span> <span class="n">min_hu</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">max_hu</span> <span class="o">-</span> <span class="n">min_hu</span><span class="p">)</span>
    <span class="n">volume</span> <span class="o">=</span> <span class="n">volume</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">&quot;float32&quot;</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">volume</span>


<span class="k">def</span> <span class="nf">resize_volume</span><span class="p">(</span><span class="n">img</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;Rotate the volume and resize it along width, height and depth&quot;&quot;&quot;</span>
    <span class="c1"># Set the desired dimensions</span>
    <span class="n">desired_depth</span> <span class="o">=</span> <span class="mi">64</span>
    <span class="n">desired_width</span> <span class="o">=</span> <span class="mi">128</span>
    <span class="n">desired_height</span> <span class="o">=</span> <span class="mi">128</span>
    <span class="c1"># Get current dimensions</span>
    <span class="n">current_depth</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
    <span class="n">current_width</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
    <span class="n">current_height</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span>
    <span class="c1"># Compute the zoom factor for each axis</span>
    <span class="n">depth</span> <span class="o">=</span> <span class="n">current_depth</span> <span class="o">/</span> <span class="n">desired_depth</span>
    <span class="n">width</span> <span class="o">=</span> <span class="n">current_width</span> <span class="o">/</span> <span class="n">desired_width</span>
    <span class="n">height</span> <span class="o">=</span> <span class="n">current_height</span> <span class="o">/</span> <span class="n">desired_height</span>
    <span class="n">depth_factor</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">depth</span>
    <span class="n">width_factor</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">width</span>
    <span class="n">height_factor</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">/</span> <span class="n">height</span>
    <span class="c1"># Rotate</span>
    <span class="n">img</span> <span class="o">=</span> <span class="n">ndimage</span><span class="o">.</span><span class="n">rotate</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="mi">90</span><span class="p">,</span> <span class="n">reshape</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
    <span class="c1"># Resize along all three axes</span>
    <span class="n">img</span> <span class="o">=</span> <span class="n">ndimage</span><span class="o">.</span><span class="n">zoom</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="p">(</span><span class="n">width_factor</span><span class="p">,</span> <span class="n">height_factor</span><span class="p">,</span> <span class="n">depth_factor</span><span class="p">),</span> <span class="n">order</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">img</span>


<span class="k">def</span> <span class="nf">process_scan</span><span class="p">(</span><span class="n">path</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;Read, normalize and resize volume&quot;&quot;&quot;</span>
    <span class="c1"># Read scan</span>
    <span class="n">volume</span> <span class="o">=</span> <span class="n">read_nifti_file</span><span class="p">(</span><span class="n">path</span><span class="p">)</span>
    <span class="c1"># Normalize</span>
    <span class="n">volume</span> <span class="o">=</span> <span class="n">normalize</span><span class="p">(</span><span class="n">volume</span><span class="p">)</span>
    <span class="c1"># Resize width, height and depth</span>
    <span class="n">volume</span> <span class="o">=</span> <span class="n">resize_volume</span><span class="p">(</span><span class="n">volume</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">volume</span>
</code></pre></div> <p>Let's read the paths of the CT scans from the class directories.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Folder &quot;CT-0&quot; consists of CT scans having normal lung tissue,</span>
<span class="c1"># no CT-signs of viral pneumonia.</span>
<span class="n">normal_scan_paths</span> <span class="o">=</span> <span class="p">[</span>
    <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span 
class="n">join</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">getcwd</span><span class="p">(),</span> <span class="s2">&quot;MosMedData/CT-0&quot;</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">os</span><span class="o">.</span><span class="n">listdir</span><span class="p">(</span><span class="s2">&quot;MosMedData/CT-0&quot;</span><span class="p">)</span> <span class="p">]</span> <span class="c1"># Folder &quot;CT-23&quot; consist of CT scans having several ground-glass opacifications,</span> <span class="c1"># involvement of lung parenchyma.</span> <span class="n">abnormal_scan_paths</span> <span class="o">=</span> <span class="p">[</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="n">os</span><span class="o">.</span><span class="n">getcwd</span><span class="p">(),</span> <span class="s2">&quot;MosMedData/CT-23&quot;</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="k">for</span> <span class="n">x</span> <span class="ow">in</span> <span class="n">os</span><span class="o">.</span><span class="n">listdir</span><span class="p">(</span><span class="s2">&quot;MosMedData/CT-23&quot;</span><span class="p">)</span> <span class="p">]</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;CT scans with normal lung tissue: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">normal_scan_paths</span><span class="p">)))</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;CT scans with abnormal lung tissue: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span 
class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">abnormal_scan_paths</span><span class="p">)))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>CT scans with normal lung tissue: 100 CT scans with abnormal lung tissue: 100 </code></pre></div> </div> <hr /> <h2 id="build-train-and-validation-datasets">Build train and validation datasets</h2> <p>Read the scans from the class directories and assign labels. Downsample the scans to a shape of 128x128x64. Rescale the raw Hounsfield unit (HU) values to the range 0 to 1. Finally, split the dataset into training and validation subsets.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Read and process the scans.</span> <span class="c1"># Each scan is resized across height, width, and depth and rescaled.</span> <span class="n">abnormal_scans</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">process_scan</span><span class="p">(</span><span class="n">path</span><span class="p">)</span> <span class="k">for</span> <span class="n">path</span> <span class="ow">in</span> <span class="n">abnormal_scan_paths</span><span class="p">])</span> <span class="n">normal_scans</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="n">process_scan</span><span class="p">(</span><span class="n">path</span><span class="p">)</span> <span class="k">for</span> <span class="n">path</span> <span class="ow">in</span> <span class="n">normal_scan_paths</span><span class="p">])</span> <span class="c1"># Assign 1 to the CT scans showing viral pneumonia</span> <span class="c1"># and 0 to the normal ones.</span> <span class="n">abnormal_labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span 
class="p">([</span><span class="mi">1</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">abnormal_scans</span><span class="p">))])</span> <span class="n">normal_labels</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([</span><span class="mi">0</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">normal_scans</span><span class="p">))])</span> <span class="c1"># Split data in the ratio 70-30 for training and validation.</span> <span class="n">x_train</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">abnormal_scans</span><span class="p">[:</span><span class="mi">70</span><span class="p">],</span> <span class="n">normal_scans</span><span class="p">[:</span><span class="mi">70</span><span class="p">]),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">y_train</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">abnormal_labels</span><span class="p">[:</span><span class="mi">70</span><span class="p">],</span> <span class="n">normal_labels</span><span class="p">[:</span><span class="mi">70</span><span class="p">]),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">x_val</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span 
class="n">abnormal_scans</span><span class="p">[</span><span class="mi">70</span><span class="p">:],</span> <span class="n">normal_scans</span><span class="p">[</span><span class="mi">70</span><span class="p">:]),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">y_val</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">abnormal_labels</span><span class="p">[</span><span class="mi">70</span><span class="p">:],</span> <span class="n">normal_labels</span><span class="p">[</span><span class="mi">70</span><span class="p">:]),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span> <span class="s2">&quot;Number of samples in train and validation are </span><span class="si">%d</span><span class="s2"> and </span><span class="si">%d</span><span class="s2">.&quot;</span> <span class="o">%</span> <span class="p">(</span><span class="n">x_train</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">x_val</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Number of samples in train and validation are 140 and 60. </code></pre></div> </div> <hr /> <h2 id="data-augmentation">Data augmentation</h2> <p>The CT scans are also augmented by rotating them at random angles during training. Since the data is stored in a rank-4 tensor of shape <code>(samples, height, width, depth)</code>, we add a channel dimension of size 1 at axis 4 (axis 3 of each individual volume) to be able to perform 3D convolutions on the data. 
The new shape is thus <code>(samples, height, width, depth, 1)</code>. Many other preprocessing and augmentation techniques exist; this example shows a few simple ones to get started.</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">random</span> <span class="kn">from</span> <span class="nn">scipy</span> <span class="kn">import</span> <span class="n">ndimage</span> <span class="k">def</span> <span class="nf">rotate</span><span class="p">(</span><span class="n">volume</span><span class="p">):</span> <span class="w"> </span><span class="sd">&quot;&quot;&quot;Rotate the volume by a few degrees&quot;&quot;&quot;</span> <span class="k">def</span> <span class="nf">scipy_rotate</span><span class="p">(</span><span class="n">volume</span><span class="p">):</span> <span class="c1"># define some rotation angles</span> <span class="n">angles</span> <span class="o">=</span> <span class="p">[</span><span class="o">-</span><span class="mi">20</span><span class="p">,</span> <span class="o">-</span><span class="mi">10</span><span class="p">,</span> <span class="o">-</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">20</span><span class="p">]</span> <span class="c1"># pick one angle at random</span> <span class="n">angle</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="n">angles</span><span class="p">)</span> <span class="c1"># rotate volume</span> <span class="n">volume</span> <span class="o">=</span> <span class="n">ndimage</span><span class="o">.</span><span class="n">rotate</span><span class="p">(</span><span class="n">volume</span><span class="p">,</span> <span class="n">angle</span><span class="p">,</span> <span class="n">reshape</span><span class="o">=</span><span 
class="kc">False</span><span class="p">)</span> <span class="n">volume</span><span class="p">[</span><span class="n">volume</span> <span class="o">&lt;</span> <span class="mi">0</span><span class="p">]</span> <span class="o">=</span> <span class="mi">0</span> <span class="n">volume</span><span class="p">[</span><span class="n">volume</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">]</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">return</span> <span class="n">volume</span> <span class="n">augmented_volume</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">numpy_function</span><span class="p">(</span><span class="n">scipy_rotate</span><span class="p">,</span> <span class="p">[</span><span class="n">volume</span><span class="p">],</span> <span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="k">return</span> <span class="n">augmented_volume</span> <span class="k">def</span> <span class="nf">train_preprocessing</span><span class="p">(</span><span class="n">volume</span><span class="p">,</span> <span class="n">label</span><span class="p">):</span> <span class="w"> </span><span class="sd">&quot;&quot;&quot;Process training data by rotating and adding a channel.&quot;&quot;&quot;</span> <span class="c1"># Rotate volume</span> <span class="n">volume</span> <span class="o">=</span> <span class="n">rotate</span><span class="p">(</span><span class="n">volume</span><span class="p">)</span> <span class="n">volume</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">volume</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span> <span class="k">return</span> <span class="n">volume</span><span class="p">,</span> <span class="n">label</span> <span 
class="k">def</span> <span class="nf">validation_preprocessing</span><span class="p">(</span><span class="n">volume</span><span class="p">,</span> <span class="n">label</span><span class="p">):</span> <span class="w"> </span><span class="sd">&quot;&quot;&quot;Process validation data by only adding a channel.&quot;&quot;&quot;</span> <span class="n">volume</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">volume</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span> <span class="k">return</span> <span class="n">volume</span><span class="p">,</span> <span class="n">label</span> </code></pre></div> <p>When defining the training and validation data loaders, the training data is passed through an augmentation function that randomly rotates each volume by one of several fixed angles. Note that both the training and validation data have already been rescaled to values between 0 and 1.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Define data loaders.</span> <span class="n">train_loader</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">))</span> <span class="n">validation_loader</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">x_val</span><span class="p">,</span> <span class="n">y_val</span><span class="p">))</span> <span class="n">batch_size</span> <span class="o">=</span> <span 
class="mi">2</span> <span class="c1"># Augment the data on the fly during training.</span> <span class="n">train_dataset</span> <span class="o">=</span> <span class="p">(</span> <span class="n">train_loader</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">x_train</span><span class="p">))</span> <span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">train_preprocessing</span><span class="p">)</span> <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span> <span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span> <span class="p">)</span> <span class="c1"># Only add a channel dimension (no augmentation).</span> <span class="n">validation_dataset</span> <span class="o">=</span> <span class="p">(</span> <span class="n">validation_loader</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">x_val</span><span class="p">))</span> <span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">validation_preprocessing</span><span class="p">)</span> <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span> <span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="mi">2</span><span class="p">)</span> <span class="p">)</span> </code></pre></div> <p>Visualize an augmented CT scan.</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span> <span class="n">data</span> <span class="o">=</span> <span class="n">train_dataset</span><span class="o">.</span><span class="n">take</span><span 
class="p">(</span><span class="mi">1</span><span class="p">)</span> <span class="n">images</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">data</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="n">images</span> <span class="o">=</span> <span class="n">images</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="n">image</span> <span class="o">=</span> <span class="n">images</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Dimension of the CT scan is:&quot;</span><span class="p">,</span> <span class="n">image</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">image</span><span class="p">[:,</span> <span class="p">:,</span> <span class="mi">30</span><span class="p">]),</span> <span class="n">cmap</span><span class="o">=</span><span class="s2">&quot;gray&quot;</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Dimension of the CT scan is: (128, 128, 64, 1) &lt;matplotlib.image.AxesImage at 0x7fc5b9900d50&gt; </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/3D_image_classification/3D_image_classification_17_2.png" /></p> <p>Since a CT scan has many slices, let's visualize a montage of the slices.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">plot_slices</span><span class="p">(</span><span class="n">num_rows</span><span class="p">,</span> <span 
class="n">num_columns</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span> <span class="w"> </span><span class="sd">&quot;&quot;&quot;Plot a montage of num_rows * num_columns CT slices&quot;&quot;&quot;</span> <span class="n">data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">rot90</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">data</span><span class="p">))</span> <span class="n">data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> <span class="n">data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">data</span><span class="p">,</span> <span class="p">(</span><span class="n">num_rows</span><span class="p">,</span> <span class="n">num_columns</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="n">height</span><span class="p">))</span> <span class="n">rows_data</span><span class="p">,</span> <span class="n">columns_data</span> <span class="o">=</span> <span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="n">heights</span> <span class="o">=</span> <span class="p">[</span><span class="n">slc</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span 
class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">slc</span> <span class="ow">in</span> <span class="n">data</span><span class="p">]</span> <span class="n">widths</span> <span class="o">=</span> <span class="p">[</span><span class="n">slc</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="k">for</span> <span class="n">slc</span> <span class="ow">in</span> <span class="n">data</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span> <span class="n">fig_width</span> <span class="o">=</span> <span class="mf">12.0</span> <span class="n">fig_height</span> <span class="o">=</span> <span class="n">fig_width</span> <span class="o">*</span> <span class="nb">sum</span><span class="p">(</span><span class="n">heights</span><span class="p">)</span> <span class="o">/</span> <span class="nb">sum</span><span class="p">(</span><span class="n">widths</span><span class="p">)</span> <span class="n">f</span><span class="p">,</span> <span class="n">axarr</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span> <span class="n">rows_data</span><span class="p">,</span> <span class="n">columns_data</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="n">fig_width</span><span class="p">,</span> <span class="n">fig_height</span><span class="p">),</span> <span class="n">gridspec_kw</span><span class="o">=</span><span class="p">{</span><span class="s2">&quot;height_ratios&quot;</span><span class="p">:</span> <span class="n">heights</span><span class="p">},</span> <span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">rows_data</span><span class="p">):</span> <span 
class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">columns_data</span><span class="p">):</span> <span class="n">axarr</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">data</span><span class="p">[</span><span class="n">i</span><span class="p">][</span><span class="n">j</span><span class="p">],</span> <span class="n">cmap</span><span class="o">=</span><span class="s2">&quot;gray&quot;</span><span class="p">)</span> <span class="n">axarr</span><span class="p">[</span><span class="n">i</span><span class="p">,</span> <span class="n">j</span><span class="p">]</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">&quot;off&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots_adjust</span><span class="p">(</span><span class="n">wspace</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">hspace</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">left</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">right</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">bottom</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">top</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="c1"># Visualize montage of slices.</span> <span class="c1"># 4 rows and 10 columns for 40 slices of the CT scan.</span> <span class="n">plot_slices</span><span class="p">(</span><span 
class="mi">4</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="mi">128</span><span class="p">,</span> <span class="n">image</span><span class="p">[:,</span> <span class="p">:,</span> <span class="p">:</span><span class="mi">40</span><span class="p">])</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/3D_image_classification/3D_image_classification_19_0.png" /></p> <hr /> <h2 id="define-a-3d-convolutional-neural-network">Define a 3D convolutional neural network</h2> <p>To make the model easier to understand, we structure it into blocks. The architecture of the 3D CNN used in this example is based on <a href="https://arxiv.org/abs/2007.13224">this paper</a>.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">get_model</span><span class="p">(</span><span class="n">width</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">height</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">depth</span><span class="o">=</span><span class="mi">64</span><span class="p">):</span> <span class="w"> </span><span class="sd">&quot;&quot;&quot;Build a 3D convolutional neural network model.&quot;&quot;&quot;</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">((</span><span class="n">width</span><span class="p">,</span> <span class="n">height</span><span class="p">,</span> <span class="n">depth</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv3D</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> 
<span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MaxPool3D</span><span class="p">(</span><span class="n">pool_size</span><span class="o">=</span><span class="mi">2</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv3D</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MaxPool3D</span><span class="p">(</span><span class="n">pool_size</span><span class="o">=</span><span class="mi">2</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span 
class="n">Conv3D</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MaxPool3D</span><span class="p">(</span><span class="n">pool_size</span><span class="o">=</span><span class="mi">2</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv3D</span><span class="p">(</span><span class="n">filters</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MaxPool3D</span><span class="p">(</span><span class="n">pool_size</span><span class="o">=</span><span class="mi">2</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">()(</span><span 
class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">GlobalAveragePooling3D</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="mi">512</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.3</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;sigmoid&quot;</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Define the model.</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;3dcnn&quot;</span><span class="p">)</span> <span class="k">return</span> <span class="n">model</span> <span class="c1"># Build model.</span> <span class="n">model</span> <span class="o">=</span> <span 
class="n">get_model</span><span class="p">(</span><span class="n">width</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">height</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span> <span class="n">depth</span><span class="o">=</span><span class="mi">64</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> </code></pre></div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "3dcnn"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ input_layer (<span style="color: #0087ff; text-decoration-color: #0087ff">InputLayer</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv3d (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv3D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">126</span>, <span style="color: 
#00af00; text-decoration-color: #00af00">126</span>, <span style="color: #00af00; text-decoration-color: #00af00">62</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">1,792</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ max_pooling3d (<span style="color: #0087ff; text-decoration-color: #0087ff">MaxPooling3D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">63</span>, <span style="color: #00af00; text-decoration-color: #00af00">63</span>, <span style="color: #00af00; text-decoration-color: #00af00">31</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ batch_normalization │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">63</span>, <span style="color: #00af00; text-decoration-color: #00af00">63</span>, <span style="color: #00af00; text-decoration-color: #00af00">31</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">256</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">BatchNormalization</span>) │ │ │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv3d_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv3D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">61</span>, <span style="color: #00af00; text-decoration-color: #00af00">61</span>, <span style="color: #00af00; text-decoration-color: #00af00">29</span>, <span 
style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">110,656</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ max_pooling3d_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">MaxPooling3D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">30</span>, <span style="color: #00af00; text-decoration-color: #00af00">30</span>, <span style="color: #00af00; text-decoration-color: #00af00">14</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ batch_normalization_1 │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">30</span>, <span style="color: #00af00; text-decoration-color: #00af00">30</span>, <span style="color: #00af00; text-decoration-color: #00af00">14</span>, <span style="color: #00af00; text-decoration-color: #00af00">64</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">256</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">BatchNormalization</span>) │ │ │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv3d_2 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv3D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">28</span>, <span style="color: #00af00; text-decoration-color: #00af00">28</span>, <span style="color: #00af00; text-decoration-color: #00af00">12</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: 
#00af00">221,312</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ max_pooling3d_2 (<span style="color: #0087ff; text-decoration-color: #0087ff">MaxPooling3D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">14</span>, <span style="color: #00af00; text-decoration-color: #00af00">14</span>, <span style="color: #00af00; text-decoration-color: #00af00">6</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ batch_normalization_2 │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">14</span>, <span style="color: #00af00; text-decoration-color: #00af00">14</span>, <span style="color: #00af00; text-decoration-color: #00af00">6</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">512</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">BatchNormalization</span>) │ │ │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv3d_3 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv3D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">12</span>, <span style="color: #00af00; text-decoration-color: #00af00">12</span>, <span style="color: #00af00; text-decoration-color: #00af00">4</span>, <span style="color: #00af00; text-decoration-color: #00af00">256</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">884,992</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ max_pooling3d_3 (<span 
style="color: #0087ff; text-decoration-color: #0087ff">MaxPooling3D</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">6</span>, <span style="color: #00af00; text-decoration-color: #00af00">6</span>, <span style="color: #00af00; text-decoration-color: #00af00">2</span>, <span style="color: #00af00; text-decoration-color: #00af00">256</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ batch_normalization_3 │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">6</span>, <span style="color: #00af00; text-decoration-color: #00af00">6</span>, <span style="color: #00af00; text-decoration-color: #00af00">2</span>, <span style="color: #00af00; text-decoration-color: #00af00">256</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">1,024</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">BatchNormalization</span>) │ │ │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ global_average_pooling3d │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">256</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ │ (<span style="color: #0087ff; text-decoration-color: #0087ff">GlobalAveragePooling3D</span>) │ │ │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dense (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">131,584</span> │ 
├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dropout (<span style="color: #0087ff; text-decoration-color: #0087ff">Dropout</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">512</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dense_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">513</span> │ └─────────────────────────────────┴───────────────────────────┴────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">1,352,897</span> (5.16 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">1,351,873</span> (5.16 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">1,024</span> (4.00 KB) </pre> <hr /> <h2 id="train-model">Train model</h2> <div class="codehilite"><pre><span></span><code><span class="c1"># Compile model.</span> <span class="n">initial_learning_rate</span> <span class="o">=</span> <span class="mf">0.0001</span> <span class="n">lr_schedule</span> <span 
class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">schedules</span><span class="o">.</span><span class="n">ExponentialDecay</span><span class="p">(</span> <span class="n">initial_learning_rate</span><span class="p">,</span> <span class="n">decay_steps</span><span class="o">=</span><span class="mi">100000</span><span class="p">,</span> <span class="n">decay_rate</span><span class="o">=</span><span class="mf">0.96</span><span class="p">,</span> <span class="n">staircase</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">loss</span><span class="o">=</span><span class="s2">&quot;binary_crossentropy&quot;</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="n">lr_schedule</span><span class="p">),</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;acc&quot;</span><span class="p">],</span> <span class="n">run_eagerly</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Define callbacks.</span> <span class="n">checkpoint_cb</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">ModelCheckpoint</span><span class="p">(</span> <span class="s2">&quot;3d_image_classification.keras&quot;</span><span class="p">,</span> <span class="n">save_best_only</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span 
class="n">early_stopping_cb</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">EarlyStopping</span><span class="p">(</span><span class="n">monitor</span><span class="o">=</span><span class="s2">&quot;val_acc&quot;</span><span class="p">,</span> <span class="n">patience</span><span class="o">=</span><span class="mi">15</span><span class="p">)</span> <span class="c1"># Train the model, doing validation at the end of each epoch</span> <span class="n">epochs</span> <span class="o">=</span> <span class="mi">100</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">train_dataset</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">validation_dataset</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">epochs</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">checkpoint_cb</span><span class="p">,</span> <span class="n">early_stopping_cb</span><span class="p">],</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/100 70/70 - 40s - 568ms/step - acc: 0.5786 - loss: 0.7128 - val_acc: 0.5000 - val_loss: 0.8744 Epoch 2/100 70/70 - 26s - 370ms/step - acc: 0.6000 - loss: 0.6760 - val_acc: 0.5000 - val_loss: 1.2741 Epoch 3/100 70/70 - 26s - 373ms/step - acc: 0.5643 - loss: 0.6768 - val_acc: 0.5000 - val_loss: 1.4767 Epoch 4/100 70/70 - 26s - 376ms/step - acc: 0.6643 - loss: 0.6671 - val_acc: 0.5000 - val_loss: 1.2609 Epoch 5/100 70/70 - 
26s - 374ms/step - acc: 0.6714 - loss: 0.6274 - val_acc: 0.5667 - val_loss: 0.6470 Epoch 6/100 70/70 - 26s - 372ms/step - acc: 0.5929 - loss: 0.6492 - val_acc: 0.6667 - val_loss: 0.6022 Epoch 7/100 70/70 - 26s - 374ms/step - acc: 0.5929 - loss: 0.6601 - val_acc: 0.5667 - val_loss: 0.6788 Epoch 8/100 70/70 - 26s - 378ms/step - acc: 0.6000 - loss: 0.6559 - val_acc: 0.6667 - val_loss: 0.6090 Epoch 9/100 70/70 - 26s - 373ms/step - acc: 0.6357 - loss: 0.6423 - val_acc: 0.6000 - val_loss: 0.6535 Epoch 10/100 70/70 - 26s - 374ms/step - acc: 0.6500 - loss: 0.6127 - val_acc: 0.6500 - val_loss: 0.6204 Epoch 11/100 70/70 - 26s - 374ms/step - acc: 0.6714 - loss: 0.5994 - val_acc: 0.7000 - val_loss: 0.6218 Epoch 12/100 70/70 - 26s - 374ms/step - acc: 0.6714 - loss: 0.5980 - val_acc: 0.7167 - val_loss: 0.5069 Epoch 13/100 70/70 - 26s - 369ms/step - acc: 0.7214 - loss: 0.6003 - val_acc: 0.7833 - val_loss: 0.5182 Epoch 14/100 70/70 - 26s - 372ms/step - acc: 0.6643 - loss: 0.6076 - val_acc: 0.7167 - val_loss: 0.5613 Epoch 15/100 70/70 - 26s - 373ms/step - acc: 0.6571 - loss: 0.6359 - val_acc: 0.6167 - val_loss: 0.6184 Epoch 16/100 70/70 - 26s - 374ms/step - acc: 0.6429 - loss: 0.6053 - val_acc: 0.7167 - val_loss: 0.5258 Epoch 17/100 70/70 - 26s - 370ms/step - acc: 0.6786 - loss: 0.6119 - val_acc: 0.5667 - val_loss: 0.8481 Epoch 18/100 70/70 - 26s - 372ms/step - acc: 0.6286 - loss: 0.6298 - val_acc: 0.6667 - val_loss: 0.5709 Epoch 19/100 70/70 - 26s - 372ms/step - acc: 0.7214 - loss: 0.5979 - val_acc: 0.5833 - val_loss: 0.6730 Epoch 20/100 70/70 - 26s - 372ms/step - acc: 0.7571 - loss: 0.5224 - val_acc: 0.7167 - val_loss: 0.5710 Epoch 21/100 70/70 - 26s - 372ms/step - acc: 0.7357 - loss: 0.5606 - val_acc: 0.7167 - val_loss: 0.5444 Epoch 22/100 70/70 - 26s - 372ms/step - acc: 0.7357 - loss: 0.5334 - val_acc: 0.5667 - val_loss: 0.7919 Epoch 23/100 70/70 - 26s - 373ms/step - acc: 0.7071 - loss: 0.5337 - val_acc: 0.5167 - val_loss: 0.9527 Epoch 24/100 70/70 - 26s - 371ms/step - acc: 
0.7071 - loss: 0.5635 - val_acc: 0.7167 - val_loss: 0.5333 Epoch 25/100 70/70 - 26s - 373ms/step - acc: 0.7643 - loss: 0.4787 - val_acc: 0.6333 - val_loss: 1.0172 Epoch 26/100 70/70 - 26s - 372ms/step - acc: 0.7357 - loss: 0.5535 - val_acc: 0.6500 - val_loss: 0.6926 Epoch 27/100 70/70 - 26s - 370ms/step - acc: 0.7286 - loss: 0.5608 - val_acc: 0.5000 - val_loss: 3.3032 Epoch 28/100 70/70 - 26s - 370ms/step - acc: 0.7429 - loss: 0.5436 - val_acc: 0.6500 - val_loss: 0.6438 &lt;keras.src.callbacks.history.History at 0x7fc5b923e810&gt; </code></pre></div> </div> <p>Note that the number of samples is very small (only 200) and we don't specify a random seed, so you can expect significant variance in the results. The full dataset, which consists of over 1000 CT scans, can be found <a href="https://www.medrxiv.org/content/10.1101/2020.05.20.20100362v1">here</a>. Using the full dataset, an accuracy of 83% was achieved. A variability of 6-7% in classification performance is observed in both cases.</p> <hr /> <h2 id="visualizing-model-performance">Visualizing model performance</h2> <p>Here, the model accuracy and loss for the training and validation sets are plotted. 
Since the validation set is class-balanced, accuracy provides an unbiased representation of the model's performance.</p> <div class="codehilite"><pre><span></span><code><span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">ax</span><span class="o">.</span><span class="n">ravel</span><span class="p">()</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">metric</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">([</span><span class="s2">&quot;acc&quot;</span><span class="p">,</span> <span class="s2">&quot;loss&quot;</span><span class="p">]):</span> <span class="n">ax</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="n">metric</span><span class="p">])</span> <span class="n">ax</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">&quot;val_&quot;</span> <span class="o">+</span> <span class="n">metric</span><span class="p">])</span> <span 
class="n">ax</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Model </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">metric</span><span class="p">))</span> <span class="n">ax</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s2">&quot;epochs&quot;</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="n">metric</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">legend</span><span class="p">([</span><span class="s2">&quot;train&quot;</span><span class="p">,</span> <span class="s2">&quot;val&quot;</span><span class="p">])</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/3D_image_classification/3D_image_classification_26_0.png" /></p> <hr /> <h2 id="make-predictions-on-a-single-ct-scan">Make predictions on a single CT scan</h2> <div class="codehilite"><pre><span></span><code><span class="c1"># Load best weights.</span> <span class="n">model</span><span class="o">.</span><span class="n">load_weights</span><span class="p">(</span><span class="s2">&quot;3d_image_classification.keras&quot;</span><span class="p">)</span> <span class="n">prediction</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span 
class="n">x_val</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span> <span class="n">scores</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span> <span class="o">-</span> <span class="n">prediction</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">prediction</span><span class="p">[</span><span class="mi">0</span><span class="p">]]</span> <span class="n">class_names</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;normal&quot;</span><span class="p">,</span> <span class="s2">&quot;abnormal&quot;</span><span class="p">]</span> <span class="k">for</span> <span class="n">score</span><span class="p">,</span> <span class="n">name</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">class_names</span><span class="p">):</span> <span class="nb">print</span><span class="p">(</span> <span class="s2">&quot;This model is </span><span class="si">%.2f</span><span class="s2"> percent confident that CT scan is </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="p">((</span><span class="mi">100</span> <span class="o">*</span> <span class="n">score</span><span class="p">),</span> <span class="n">name</span><span class="p">)</span> <span class="p">)</span> </code></pre></div> <p>1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 479ms/step</p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>This model is 32.99 percent confident that CT scan is normal This model is 67.01 percent confident that 
CT scan is abnormal </code></pre></div> </div> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#3d-image-classification-from-ct-scans'>3D image classification from CT scans</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#references'>References</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#downloading-the-mosmeddata-chest-ct-scans-with-covid19-related-findings'>Downloading the MosMedData: Chest CT Scans with COVID-19 Related Findings</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#loading-data-and-preprocessing'>Loading data and preprocessing</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#build-train-and-validation-datasets'>Build train and validation datasets</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#data-augmentation'>Data augmentation</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#define-a-3d-convolutional-neural-network'>Define a 3D convolutional neural network</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-model'>Train model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#visualizing-model-performance'>Visualizing model performance</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#make-predictions-on-a-single-ct-scan'>Make predictions on a single CT scan</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>
