<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/xray_classification_with_tpus/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Pneumonia Classification on TPU"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Pneumonia Classification on TPU"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Pneumonia Classification on TPU</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); 
</script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a class="nav-sublink2" 
href="/examples/vision/mobilevit/">A mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2 active" href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using Composable 
Fully-Convolutional Networks</a> <a class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved Robustness</a> <a 
class="nav-sublink2" href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a class="nav-sublink2" 
href="/examples/vision/bit/">Image Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a class="nav-sublink2" 
href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var 
menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div 
class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/vision/'>Computer Vision</a> / Pneumonia Classification on TPU </div> <div class='k-content'> <h1 id="pneumonia-classification-on-tpu">Pneumonia Classification on TPU</h1> <p><strong>Author:</strong> Amy MiHyun Jang<br> <strong>Date created:</strong> 2020/07/28<br> <strong>Last modified:</strong> 2024/02/12<br> <strong>Description:</strong> Medical image classification on TPU.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/xray_classification_with_tpus.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/xray_classification_with_tpus.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction--setup">Introduction + Set-up</h2> <p>This tutorial explains how to build an image classification model that predicts whether a chest X-ray scan shows the presence of pneumonia.</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">re</span>
<span class="kn">import</span> <span class="nn">os</span>
<span class="kn">import</span> <span class="nn">random</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>

<span class="k">try</span><span class="p">:</span>
    <span class="n">tpu</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">distribute</span><span class="o">.</span><span class="n">cluster_resolver</span><span class="o">.</span><span class="n">TPUClusterResolver</span><span class="o">.</span><span class="n">connect</span><span class="p">()</span>
    <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Device:&quot;</span><span class="p">,</span> <span class="n">tpu</span><span class="o">.</span><span class="n">master</span><span class="p">())</span>
    <span class="n">strategy</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">distribute</span><span class="o">.</span><span class="n">TPUStrategy</span><span class="p">(</span><span class="n">tpu</span><span class="p">)</span>
<span class="k">except</span> <span class="ne">Exception</span><span class="p">:</span>
    <span class="c1"># No TPU could be reached: fall back to the default strategy.</span>
    <span class="n">strategy</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">distribute</span><span class="o">.</span><span class="n">get_strategy</span><span class="p">()</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Number of replicas:&quot;</span><span class="p">,</span> <span class="n">strategy</span><span class="o">.</span><span class="n">num_replicas_in_sync</span><span class="p">)</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Device: grpc://10.0.27.122:8470 INFO:tensorflow:Initializing the TPU system: grpc://10.0.27.122:8470 INFO:tensorflow:Initializing the TPU system: grpc://10.0.27.122:8470 INFO:tensorflow:Clearing out eager caches INFO:tensorflow:Clearing out eager caches INFO:tensorflow:Finished initializing TPU system. INFO:tensorflow:Finished initializing TPU system. 
WARNING:absl:[`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy) is deprecated, please use the non experimental symbol [`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy) instead. INFO:tensorflow:Found TPU system: INFO:tensorflow:Found TPU system: INFO:tensorflow:*** Num TPU Cores: 8 INFO:tensorflow:*** Num TPU Cores: 8 INFO:tensorflow:*** Num TPU Workers: 1 INFO:tensorflow:*** Num TPU Workers: 1 INFO:tensorflow:*** Num TPU Cores Per Worker: 8 INFO:tensorflow:*** Num TPU Cores Per Worker: 8 INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0) INFO:tensorflow:*** Available Device: 
_DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0) INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0) Number of replicas: 8 </code></pre></div> </div> <p>We need a Google Cloud link to our data to load the data using a TPU. Below, we define key configuration parameters we'll use in this example. 
To run on TPU, this example must be run on Colab with the TPU runtime selected.</p> <div class="codehilite"><pre><span></span><code><span class="n">AUTOTUNE</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span>
<span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">25</span> <span class="o">*</span> <span class="n">strategy</span><span class="o">.</span><span class="n">num_replicas_in_sync</span>
<span class="n">IMAGE_SIZE</span> <span class="o">=</span> <span class="p">[</span><span class="mi">180</span><span class="p">,</span> <span class="mi">180</span><span class="p">]</span>
<span class="n">CLASS_NAMES</span> <span class="o">=</span> <span class="p">[</span><span class="s2">&quot;NORMAL&quot;</span><span class="p">,</span> <span class="s2">&quot;PNEUMONIA&quot;</span><span class="p">]</span>
</code></pre></div> <hr /> <h2 id="load-the-data">Load the data</h2> <p>The chest X-ray dataset we are using, from <a href="https://www.cell.com/cell/fulltext/S0092-8674(18)30154-5"><em>Cell</em></a>, is divided into training and test files. Let's first load the training TFRecords.</p> <div class="codehilite"><pre><span></span><code><span class="n">train_images</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">TFRecordDataset</span><span class="p">(</span>
    <span class="s2">&quot;gs://download.tensorflow.org/data/ChestXRay2017/train/images.tfrec&quot;</span>
<span class="p">)</span>
<span class="n">train_paths</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">TFRecordDataset</span><span class="p">(</span>
    <span class="s2">&quot;gs://download.tensorflow.org/data/ChestXRay2017/train/paths.tfrec&quot;</span>
<span class="p">)</span>

<span class="n">ds</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">zip</span><span class="p">((</span><span class="n">train_images</span><span class="p">,</span> <span class="n">train_paths</span><span class="p">))</span>
</code></pre></div> <p>Let's count how many healthy/normal chest X-rays we have and how many pneumonia chest X-rays we have:</p> <div class="codehilite"><pre><span></span><code><span class="n">COUNT_NORMAL</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span>
    <span class="p">[</span>
        <span class="n">filename</span>
        <span class="k">for</span> <span class="n">filename</span> <span class="ow">in</span> <span class="n">train_paths</span>
        <span class="k">if</span> <span class="s2">&quot;NORMAL&quot;</span> <span class="ow">in</span> <span class="n">filename</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span>
    <span class="p">]</span>
<span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Normal images count in training set: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">COUNT_NORMAL</span><span class="p">))</span>

<span class="n">COUNT_PNEUMONIA</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span>
    <span class="p">[</span>
        <span class="n">filename</span>
        <span class="k">for</span> <span class="n">filename</span> <span class="ow">in</span> <span class="n">train_paths</span>
        <span class="k">if</span> <span class="s2">&quot;PNEUMONIA&quot;</span> <span class="ow">in</span> <span class="n">filename</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">decode</span><span class="p">(</span><span class="s2">&quot;utf-8&quot;</span><span class="p">)</span>
    <span class="p">]</span>
<span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Pneumonia images count in training set: &quot;</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">COUNT_PNEUMONIA</span><span class="p">))</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Normal images count in training set: 1349
Pneumonia images count in training set: 3883
</code></pre></div> </div> <p>Notice that far more images are labeled pneumonia than normal (3883 versus 1349), so our data is imbalanced. We will correct for this imbalance later in the notebook.</p> <p>We want to map each filename to the corresponding (image, label) pair. 
The following functions will help us do that.</p> <p>As we only have two labels, we will encode the label so that <code>1</code> or <code>True</code> indicates pneumonia and <code>0</code> or <code>False</code> indicates normal.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">get_label</span><span class="p">(</span><span class="n">file_path</span><span class="p">):</span>
    <span class="c1"># convert the path to a list of path components</span>
    <span class="n">parts</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">strings</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">file_path</span><span class="p">,</span> <span class="s2">&quot;/&quot;</span><span class="p">)</span>
    <span class="c1"># The second to last is the class-directory</span>
    <span class="k">if</span> <span class="n">parts</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">]</span> <span class="o">==</span> <span class="s2">&quot;PNEUMONIA&quot;</span><span class="p">:</span>
        <span class="k">return</span> <span class="mi">1</span>
    <span class="k">else</span><span class="p">:</span>
        <span class="k">return</span> <span class="mi">0</span>


<span class="k">def</span> <span class="nf">decode_img</span><span class="p">(</span><span class="n">img</span><span class="p">):</span>
    <span class="c1"># convert the compressed string to a 3D uint8 tensor</span>
    <span class="n">img</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">decode_jpeg</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">channels</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span>
    <span class="c1"># resize the image to the desired size.</span>
    <span class="k">return</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">IMAGE_SIZE</span><span class="p">)</span>


<span class="k">def</span> <span class="nf">process_path</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">path</span><span class="p">):</span>
    <span class="n">label</span> <span class="o">=</span> <span class="n">get_label</span><span class="p">(</span><span class="n">path</span><span class="p">)</span>
    <span class="c1"># decode the raw JPEG bytes and resize the image</span>
    <span class="n">img</span> <span class="o">=</span> <span class="n">decode_img</span><span class="p">(</span><span class="n">image</span><span class="p">)</span>
    <span class="k">return</span> <span class="n">img</span><span class="p">,</span> <span class="n">label</span>


<span class="n">ds</span> <span class="o">=</span> <span class="n">ds</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">process_path</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTOTUNE</span><span class="p">)</span>
</code></pre></div> <p>Let's split the data into training and validation datasets.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Keep the shuffle order fixed across iterations so the two splits stay disjoint.</span>
<span class="n">ds</span> <span class="o">=</span> <span class="n">ds</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="mi">10000</span><span class="p">,</span> <span class="n">reshuffle_each_iteration</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
<span class="n">train_ds</span> <span class="o">=</span> <span class="n">ds</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="mi">4200</span><span class="p">)</span>
<span class="n">val_ds</span> <span class="o">=</span> <span class="n">ds</span><span class="o">.</span><span class="n">skip</span><span 
class="p">(</span><span class="mi">4200</span><span class="p">)</span> </code></pre></div> <p>Let's visualize the shape of an (image, label) pair.</p> <div class="codehilite"><pre><span></span><code><span class="k">for</span> <span class="n">image</span><span class="p">,</span> <span class="n">label</span> <span class="ow">in</span> <span class="n">train_ds</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="mi">1</span><span class="p">):</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Image shape: &quot;</span><span class="p">,</span> <span class="n">image</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Label: &quot;</span><span class="p">,</span> <span class="n">label</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Image shape: (180, 180, 3) Label: False </code></pre></div> </div> <p>Load and format the test data as well.</p> <div class="codehilite"><pre><span></span><code><span class="n">test_images</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">TFRecordDataset</span><span class="p">(</span> <span class="s2">&quot;gs://download.tensorflow.org/data/ChestXRay2017/test/images.tfrec&quot;</span> <span class="p">)</span> <span class="n">test_paths</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">TFRecordDataset</span><span class="p">(</span> <span class="s2">&quot;gs://download.tensorflow.org/data/ChestXRay2017/test/paths.tfrec&quot;</span> <span class="p">)</span> <span 
class="n">test_ds</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">zip</span><span class="p">((</span><span class="n">test_images</span><span class="p">,</span> <span class="n">test_paths</span><span class="p">))</span> <span class="n">test_ds</span> <span class="o">=</span> <span class="n">test_ds</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">process_path</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTOTUNE</span><span class="p">)</span> <span class="n">test_ds</span> <span class="o">=</span> <span class="n">test_ds</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="visualize-the-dataset">Visualize the dataset</h2> <p>First, let's use buffered prefetching so we can yield data from disk without having I/O become blocking.</p> <p>Please note that large image datasets should not be cached in memory. 
We do it here because the dataset is not very large and we want to train on TPU.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">prepare_for_training</span><span class="p">(</span><span class="n">ds</span><span class="p">,</span> <span class="n">cache</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span> <span class="c1"># This is a small dataset, only load it once, and keep it in memory.</span> <span class="c1"># use `.cache(filename)` to cache preprocessing work for datasets that don&#39;t</span> <span class="c1"># fit in memory.</span> <span class="k">if</span> <span class="n">cache</span><span class="p">:</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">cache</span><span class="p">,</span> <span class="nb">str</span><span class="p">):</span> <span class="n">ds</span> <span class="o">=</span> <span class="n">ds</span><span class="o">.</span><span class="n">cache</span><span class="p">(</span><span class="n">cache</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">ds</span> <span class="o">=</span> <span class="n">ds</span><span class="o">.</span><span class="n">cache</span><span class="p">()</span> <span class="n">ds</span> <span class="o">=</span> <span class="n">ds</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span> <span class="c1"># `prefetch` lets the dataset fetch batches in the background while the model</span> <span class="c1"># is training.</span> <span class="n">ds</span> <span class="o">=</span> <span class="n">ds</span><span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">buffer_size</span><span class="o">=</span><span class="n">AUTOTUNE</span><span class="p">)</span> <span class="k">return</span> <span class="n">ds</span> </code></pre></div> 
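<p>To make the batching step concrete, here is a minimal plain-Python sketch of what grouping elements into batches does. It is not how <code>tf.data</code> is implemented; the helper name <code>batch</code> is hypothetical, and the batch size of 200 is an assumption for illustration (4200 matches the <code>take(4200)</code> split above, while <code>BATCH_SIZE</code> itself is defined outside this section). Like <code>Dataset.batch</code> with its default settings, the sketch keeps a smaller final batch when the element count is not a multiple of the batch size.</p>

```python
import math


def batch(items, batch_size):
    """Group a sequence into consecutive lists of at most `batch_size` items."""
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]


# Assumed numbers for illustration only: 4200 examples, batch size 200.
examples = list(range(4200))
batches = batch(examples, 200)
num_batches = len(batches)

# The number of batches (steps per epoch) is ceil(n / batch_size).
assert num_batches == math.ceil(4200 / 200)
print(num_batches)

# When the count is not a multiple of the batch size, the last batch is smaller.
small_sizes = [len(b) for b in batch(list(range(10)), 4)]
print(small_sizes)
```

<p>Under these assumed numbers, one pass over the training data would take 21 steps.</p>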
<p>Prepare both datasets for training and fetch the first batch of the training data.</p> <div class="codehilite"><pre><span></span><code><span class="n">train_ds</span> <span class="o">=</span> <span class="n">prepare_for_training</span><span class="p">(</span><span class="n">train_ds</span><span class="p">)</span> <span class="n">val_ds</span> <span class="o">=</span> <span class="n">prepare_for_training</span><span class="p">(</span><span class="n">val_ds</span><span class="p">)</span> <span class="n">image_batch</span><span class="p">,</span> <span class="n">label_batch</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">train_ds</span><span class="p">))</span> </code></pre></div> <p>Define a function to show the images in the batch.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">show_batch</span><span class="p">(</span><span class="n">image_batch</span><span class="p">,</span> <span class="n">label_batch</span><span class="p">):</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">25</span><span class="p">):</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="n">n</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span 
class="p">(</span><span class="n">image_batch</span><span class="p">[</span><span class="n">n</span><span class="p">]</span> <span class="o">/</span> <span class="mi">255</span><span class="p">)</span> <span class="k">if</span> <span class="n">label_batch</span><span class="p">[</span><span class="n">n</span><span class="p">]:</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">&quot;PNEUMONIA&quot;</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">&quot;NORMAL&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">&quot;off&quot;</span><span class="p">)</span> </code></pre></div> <p>As the method takes in NumPy arrays as its parameters, call the numpy function on the batches to return the tensor in NumPy array form.</p> <div class="codehilite"><pre><span></span><code><span class="n">show_batch</span><span class="p">(</span><span class="n">image_batch</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">label_batch</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/xray_classification_with_tpus/xray_classification_with_tpus_25_0.png" /></p> <hr /> <h2 id="build-the-cnn">Build the CNN</h2> <p>To make our model more modular and easier to understand, let's define some blocks. 
As we're building a convolutional neural network, we'll create a convolution block and a dense layer block.</p> <p>The architecture for this CNN has been inspired by this <a href="https://towardsdatascience.com/deep-learning-for-detecting-pneumonia-from-x-ray-images-fc9a3d9fdba8">article</a>.</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s1">&#39;KERAS_BACKEND&#39;</span><span class="p">]</span> <span class="o">=</span> <span class="s1">&#39;tensorflow&#39;</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> <span class="k">def</span> <span class="nf">conv_block</span><span class="p">(</span><span class="n">filters</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">SeparableConv2D</span><span class="p">(</span><span class="n">filters</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">&quot;same&quot;</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">SeparableConv2D</span><span class="p">(</span><span class="n">filters</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">,</span> <span 
class="n">padding</span><span class="o">=</span><span class="s2">&quot;same&quot;</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MaxPool2D</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">outputs</span> <span class="k">def</span> <span class="nf">dense_block</span><span class="p">(</span><span class="n">units</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout_rate</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">outputs</span> </code></pre></div> <p>The following function builds our model.</p> <p>The images originally have values that range 
from [0, 255]. CNNs work better with smaller numbers so we will scale this down for our input.</p> <p>The Dropout layers are important, as they reduce the likelihood of the model overfitting. We want to end the model with a <code>Dense</code> layer with one node, as this will be the binary output that determines whether an X-ray shows the presence of pneumonia.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">build_model</span><span class="p">():</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">IMAGE_SIZE</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">IMAGE_SIZE</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="mi">3</span><span class="p">))</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Rescaling</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">/</span> <span class="mi">255</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">&quot;same&quot;</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span 
class="n">Conv2D</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">&quot;same&quot;</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">MaxPool2D</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">conv_block</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">conv_block</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">conv_block</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.2</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">conv_block</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.2</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span 
class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Flatten</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">dense_block</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">dense_block</span><span class="p">(</span><span class="mi">128</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">dense_block</span><span class="p">(</span><span class="mi">64</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;sigmoid&quot;</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">outputs</span><span class="p">)</span> <span class="k">return</span> <span class="n">model</span> </code></pre></div> <hr /> <h2 id="correct-for-data-imbalance">Correct for data imbalance</h2> <p>We saw earlier in this example that the data was imbalanced, with more images classified as 
pneumonia than normal. We will correct for that by using class weighting:</p> <div class="codehilite"><pre><span></span><code><span class="n">initial_bias</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">log</span><span class="p">([</span><span class="n">COUNT_PNEUMONIA</span> <span class="o">/</span> <span class="n">COUNT_NORMAL</span><span class="p">])</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Initial bias: </span><span class="si">{:.5f}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">initial_bias</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span> <span class="n">TRAIN_IMG_COUNT</span> <span class="o">=</span> <span class="n">COUNT_NORMAL</span> <span class="o">+</span> <span class="n">COUNT_PNEUMONIA</span> <span class="n">weight_for_0</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">/</span> <span class="n">COUNT_NORMAL</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">TRAIN_IMG_COUNT</span><span class="p">)</span> <span class="o">/</span> <span class="mf">2.0</span> <span class="n">weight_for_1</span> <span class="o">=</span> <span class="p">(</span><span class="mi">1</span> <span class="o">/</span> <span class="n">COUNT_PNEUMONIA</span><span class="p">)</span> <span class="o">*</span> <span class="p">(</span><span class="n">TRAIN_IMG_COUNT</span><span class="p">)</span> <span class="o">/</span> <span class="mf">2.0</span> <span class="n">class_weight</span> <span class="o">=</span> <span class="p">{</span><span class="mi">0</span><span class="p">:</span> <span class="n">weight_for_0</span><span class="p">,</span> <span class="mi">1</span><span class="p">:</span> <span class="n">weight_for_1</span><span class="p">}</span> <span class="nb">print</span><span 
class="p">(</span><span class="s2">&quot;Weight for class 0: </span><span class="si">{:.2f}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">weight_for_0</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Weight for class 1: </span><span class="si">{:.2f}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">weight_for_1</span><span class="p">))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Initial bias: 1.05724 Weight for class 0: 1.94 Weight for class 1: 0.67 </code></pre></div> </div> <p>The weight for class <code>0</code> (Normal) is a lot higher than the weight for class <code>1</code> (Pneumonia). Because there are fewer normal images, each normal image is weighted more heavily to balance the data, since the CNN works best when the training data is balanced.</p> <hr /> <h2 id="train-the-model">Train the model</h2> <h3 id="defining-callbacks">Defining callbacks</h3> <p>The checkpoint callback saves the best weights of the model, so next time we want to use the model, we do not have to spend time training it. 
The early stopping callback stops the training process when the model stops improving or, even worse, starts overfitting.</p> <div class="codehilite"><pre><span></span><code><span class="n">checkpoint_cb</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">ModelCheckpoint</span><span class="p">(</span><span class="s2">&quot;xray_model.keras&quot;</span><span class="p">,</span> <span class="n">save_best_only</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">early_stopping_cb</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">EarlyStopping</span><span class="p">(</span> <span class="n">patience</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">restore_best_weights</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> </code></pre></div> <p>We also want to tune our learning rate. A learning rate that is too high will cause the model to diverge. A learning rate that is too low will make training too slow. 
We implement the exponential learning rate scheduling method below.</p> <div class="codehilite"><pre><span></span><code><span class="n">initial_learning_rate</span> <span class="o">=</span> <span class="mf">0.015</span> <span class="n">lr_schedule</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">schedules</span><span class="o">.</span><span class="n">ExponentialDecay</span><span class="p">(</span> <span class="n">initial_learning_rate</span><span class="p">,</span> <span class="n">decay_steps</span><span class="o">=</span><span class="mi">100000</span><span class="p">,</span> <span class="n">decay_rate</span><span class="o">=</span><span class="mf">0.96</span><span class="p">,</span> <span class="n">staircase</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> </code></pre></div> <h3 id="fit-the-model">Fit the model</h3> <p>For our metrics, we want to include precision and recall as they will provide us with a more informed picture of how good our model is. Accuracy tells us what fraction of the labels is correct. Since our data is not balanced, accuracy might give a skewed sense of a good model (e.g. a model that always predicts PNEUMONIA will be 74% accurate but is not a good model).</p> <p>Precision is the number of true positives (TP) over the sum of TP and false positives (FP). It shows what fraction of predicted positives are actually correct.</p> <p>Recall is the number of TP over the sum of TP and false negatives (FN). It shows what fraction of actual positives are correctly identified.</p> <p>Since there are only two possible labels for the image, we will be using the binary crossentropy loss. When we fit the model, remember to specify the class weights, which we defined earlier. 
Because we are using a TPU, training will be quick - less than 2 minutes.</p> <div class="codehilite"><pre><span></span><code><span class="k">with</span> <span class="n">strategy</span><span class="o">.</span><span class="n">scope</span><span class="p">():</span> <span class="n">model</span> <span class="o">=</span> <span class="n">build_model</span><span class="p">()</span> <span class="n">METRICS</span> <span class="o">=</span> <span class="p">[</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">BinaryAccuracy</span><span class="p">(),</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Precision</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;precision&quot;</span><span class="p">),</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Recall</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;recall&quot;</span><span class="p">),</span> <span class="p">]</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="n">lr_schedule</span><span class="p">),</span> <span class="n">loss</span><span class="o">=</span><span class="s2">&quot;binary_crossentropy&quot;</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="n">METRICS</span><span class="p">,</span> <span class="p">)</span> <span class="n">history</span> <span class="o">=</span> <span 
class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">train_ds</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">100</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">val_ds</span><span class="p">,</span> <span class="n">class_weight</span><span class="o">=</span><span class="n">class_weight</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">checkpoint_cb</span><span class="p">,</span> <span class="n">early_stopping_cb</span><span class="p">],</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/100 WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.data.Iterator.get_next_as_optional()` instead. WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version. Instructions for updating: Use `tf.data.Iterator.get_next_as_optional()` instead. 
21/21 [==============================] - 12s 568ms/step - loss: 0.5857 - binary_accuracy: 0.6960 - precision: 0.8887 - recall: 0.6733 - val_loss: 34.0149 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000 Epoch 2/100 21/21 [==============================] - 3s 128ms/step - loss: 0.2916 - binary_accuracy: 0.8755 - precision: 0.9540 - recall: 0.8738 - val_loss: 97.5194 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000 Epoch 3/100 21/21 [==============================] - 4s 167ms/step - loss: 0.2384 - binary_accuracy: 0.9002 - precision: 0.9663 - recall: 0.8964 - val_loss: 27.7902 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000 Epoch 4/100 21/21 [==============================] - 4s 173ms/step - loss: 0.2046 - binary_accuracy: 0.9145 - precision: 0.9725 - recall: 0.9102 - val_loss: 10.8302 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000 Epoch 5/100 21/21 [==============================] - 4s 174ms/step - loss: 0.1841 - binary_accuracy: 0.9279 - precision: 0.9733 - recall: 0.9279 - val_loss: 3.5860 - val_binary_accuracy: 0.7103 - val_precision: 0.7162 - val_recall: 0.9879 Epoch 6/100 21/21 [==============================] - 4s 185ms/step - loss: 0.1600 - binary_accuracy: 0.9362 - precision: 0.9791 - recall: 0.9337 - val_loss: 0.3014 - val_binary_accuracy: 0.8895 - val_precision: 0.8973 - val_recall: 0.9555 Epoch 7/100 21/21 [==============================] - 3s 130ms/step - loss: 0.1567 - binary_accuracy: 0.9393 - precision: 0.9798 - recall: 0.9372 - val_loss: 0.6763 - val_binary_accuracy: 0.7810 - val_precision: 0.7760 - val_recall: 0.9771 Epoch 8/100 21/21 [==============================] - 3s 131ms/step - loss: 0.1532 - binary_accuracy: 0.9421 - precision: 0.9825 - recall: 0.9385 - val_loss: 0.3169 - val_binary_accuracy: 0.8895 - val_precision: 0.8684 - val_recall: 0.9973 Epoch 9/100 21/21 [==============================] - 4s 184ms/step - loss: 0.1457 - 
binary_accuracy: 0.9431 - precision: 0.9822 - recall: 0.9401 - val_loss: 0.2064 - val_binary_accuracy: 0.9273 - val_precision: 0.9840 - val_recall: 0.9136
Epoch 10/100
21/21 [==============================] - 3s 132ms/step - loss: 0.1201 - binary_accuracy: 0.9521 - precision: 0.9869 - recall: 0.9479 - val_loss: 0.4364 - val_binary_accuracy: 0.8605 - val_precision: 0.8443 - val_recall: 0.9879
Epoch 11/100
21/21 [==============================] - 3s 127ms/step - loss: 0.1200 - binary_accuracy: 0.9510 - precision: 0.9863 - recall: 0.9469 - val_loss: 0.5197 - val_binary_accuracy: 0.8508 - val_precision: 1.0000 - val_recall: 0.7922
Epoch 12/100
21/21 [==============================] - 4s 186ms/step - loss: 0.1077 - binary_accuracy: 0.9581 - precision: 0.9870 - recall: 0.9559 - val_loss: 0.1349 - val_binary_accuracy: 0.9486 - val_precision: 0.9587 - val_recall: 0.9703
Epoch 13/100
21/21 [==============================] - 4s 173ms/step - loss: 0.0918 - binary_accuracy: 0.9650 - precision: 0.9914 - recall: 0.9611 - val_loss: 0.0926 - val_binary_accuracy: 0.9700 - val_precision: 0.9837 - val_recall: 0.9744
Epoch 14/100
21/21 [==============================] - 3s 130ms/step - loss: 0.0996 - binary_accuracy: 0.9612 - precision: 0.9913 - recall: 0.9559 - val_loss: 0.1811 - val_binary_accuracy: 0.9419 - val_precision: 0.9956 - val_recall: 0.9231
Epoch 15/100
21/21 [==============================] - 3s 129ms/step - loss: 0.0898 - binary_accuracy: 0.9643 - precision: 0.9901 - recall: 0.9614 - val_loss: 0.1525 - val_binary_accuracy: 0.9486 - val_precision: 0.9986 - val_recall: 0.9298
Epoch 16/100
21/21 [==============================] - 3s 128ms/step - loss: 0.0941 - binary_accuracy: 0.9621 - precision: 0.9904 - recall: 0.9582 - val_loss: 0.5101 - val_binary_accuracy: 0.8527 - val_precision: 1.0000 - val_recall: 0.7949
Epoch 17/100
21/21 [==============================] - 3s 125ms/step - loss: 0.0798 - binary_accuracy: 0.9636 - precision: 0.9897 - recall: 0.9607 - val_loss: 0.1239 - val_binary_accuracy: 0.9622 - val_precision: 0.9875 - val_recall: 0.9595
Epoch 18/100
21/21 [==============================] - 3s 126ms/step - loss: 0.0821 - binary_accuracy: 0.9657 - precision: 0.9911 - recall: 0.9623 - val_loss: 0.1597 - val_binary_accuracy: 0.9322 - val_precision: 0.9956 - val_recall: 0.9096
Epoch 19/100
21/21 [==============================] - 3s 143ms/step - loss: 0.0800 - binary_accuracy: 0.9657 - precision: 0.9917 - recall: 0.9617 - val_loss: 0.2538 - val_binary_accuracy: 0.9109 - val_precision: 1.0000 - val_recall: 0.8758
Epoch 20/100
21/21 [==============================] - 3s 127ms/step - loss: 0.0605 - binary_accuracy: 0.9738 - precision: 0.9950 - recall: 0.9694 - val_loss: 0.6594 - val_binary_accuracy: 0.8566 - val_precision: 1.0000 - val_recall: 0.8003
Epoch 21/100
21/21 [==============================] - 4s 167ms/step - loss: 0.0726 - binary_accuracy: 0.9733 - precision: 0.9937 - recall: 0.9701 - val_loss: 0.0593 - val_binary_accuracy: 0.9816 - val_precision: 0.9945 - val_recall: 0.9798
Epoch 22/100
21/21 [==============================] - 3s 126ms/step - loss: 0.0577 - binary_accuracy: 0.9783 - precision: 0.9951 - recall: 0.9755 - val_loss: 0.1087 - val_binary_accuracy: 0.9729 - val_precision: 0.9931 - val_recall: 0.9690
Epoch 23/100
21/21 [==============================] - 3s 125ms/step - loss: 0.0652 - binary_accuracy: 0.9729 - precision: 0.9924 - recall: 0.9707 - val_loss: 1.8465 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000
Epoch 24/100
21/21 [==============================] - 3s 124ms/step - loss: 0.0538 - binary_accuracy: 0.9783 - precision: 0.9951 - recall: 0.9755 - val_loss: 1.5769 - val_binary_accuracy: 0.7180 - val_precision: 0.7180 - val_recall: 1.0000
Epoch 25/100
21/21 [==============================] - 4s 167ms/step - loss: 0.0549 - binary_accuracy: 0.9776 - precision: 0.9954 - recall: 0.9743 - val_loss: 0.0590 - val_binary_accuracy: 0.9777 - val_precision: 0.9904 - val_recall: 0.9784
Epoch 26/100
21/21 [==============================] - 3s 131ms/step - loss: 0.0677 - binary_accuracy: 0.9719 - precision: 0.9924 - recall: 0.9694 - val_loss: 2.6008 - val_binary_accuracy: 0.6928 - val_precision: 0.9977 - val_recall: 0.5735
Epoch 27/100
21/21 [==============================] - 3s 127ms/step - loss: 0.0469 - binary_accuracy: 0.9833 - precision: 0.9971 - recall: 0.9804 - val_loss: 1.0184 - val_binary_accuracy: 0.8605 - val_precision: 0.9983 - val_recall: 0.8070
Epoch 28/100
21/21 [==============================] - 3s 126ms/step - loss: 0.0501 - binary_accuracy: 0.9790 - precision: 0.9961 - recall: 0.9755 - val_loss: 0.3737 - val_binary_accuracy: 0.9089 - val_precision: 0.9954 - val_recall: 0.8772
Epoch 29/100
21/21 [==============================] - 3s 128ms/step - loss: 0.0548 - binary_accuracy: 0.9798 - precision: 0.9941 - recall: 0.9784 - val_loss: 1.2928 - val_binary_accuracy: 0.7907 - val_precision: 1.0000 - val_recall: 0.7085
Epoch 30/100
21/21 [==============================] - 3s 129ms/step - loss: 0.0370 - binary_accuracy: 0.9860 - precision: 0.9980 - recall: 0.9829 - val_loss: 0.1370 - val_binary_accuracy: 0.9612 - val_precision: 0.9972 - val_recall: 0.9487
Epoch 31/100
21/21 [==============================] - 3s 125ms/step - loss: 0.0585 - binary_accuracy: 0.9819 - precision: 0.9951 - recall: 0.9804 - val_loss: 1.1955 - val_binary_accuracy: 0.6870 - val_precision: 0.9976 - val_recall: 0.5655
Epoch 32/100
21/21 [==============================] - 3s 140ms/step - loss: 0.0813 - binary_accuracy: 0.9695 - precision: 0.9934 - recall: 0.9652 - val_loss: 1.0394 - val_binary_accuracy: 0.8576 - val_precision: 0.9853 - val_recall: 0.8138
Epoch 33/100
21/21 [==============================] - 3s 128ms/step - loss: 0.1111 - binary_accuracy: 0.9555 - precision: 0.9870 - recall: 0.9524 - val_loss: 4.9438 - val_binary_accuracy: 0.5911 - val_precision: 1.0000 - val_recall: 0.4305
Epoch 34/100
21/21 [==============================] - 3s 130ms/step - loss: 
0.0680 - binary_accuracy: 0.9726 - precision: 0.9921 - recall: 0.9707 - val_loss: 2.8822 - val_binary_accuracy: 0.7267 - val_precision: 0.9978 - val_recall: 0.6208
Epoch 35/100
21/21 [==============================] - 4s 187ms/step - loss: 0.0784 - binary_accuracy: 0.9712 - precision: 0.9892 - recall: 0.9717 - val_loss: 0.3940 - val_binary_accuracy: 0.9390 - val_precision: 0.9942 - val_recall: 0.9204
</code></pre></div> </div> <hr /> <h2 id="visualizing-model-performance">Visualizing model performance</h2> <p>Let's plot the model's precision, recall, accuracy, and loss for the training and validation sets. Note that no random seed is specified for this notebook, so your results may vary slightly.</p> <div class="codehilite"><pre><span></span><code><span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">4</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
<span class="n">ax</span> <span class="o">=</span> <span class="n">ax</span><span class="o">.</span><span class="n">ravel</span><span class="p">()</span>

<span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">met</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">([</span><span class="s2">&quot;precision&quot;</span><span class="p">,</span> <span class="s2">&quot;recall&quot;</span><span class="p">,</span> <span class="s2">&quot;binary_accuracy&quot;</span><span class="p">,</span> <span class="s2">&quot;loss&quot;</span><span class="p">]):</span>
    <span class="n">ax</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="n">met</span><span class="p">])</span>
    <span class="n">ax</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">&quot;val_&quot;</span> <span class="o">+</span> <span class="n">met</span><span class="p">])</span>
    <span class="n">ax</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">&quot;Model </span><span class="si">{}</span><span class="s2">&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">met</span><span class="p">))</span>
    <span class="n">ax</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s2">&quot;epochs&quot;</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="n">met</span><span class="p">)</span>
    <span class="n">ax</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">legend</span><span class="p">([</span><span class="s2">&quot;train&quot;</span><span class="p">,</span> <span class="s2">&quot;val&quot;</span><span class="p">])</span>
</code></pre></div> <p><img alt="png" src="/img/examples/vision/xray_classification_with_tpus/xray_classification_with_tpus_41_0.png" /></p> <p>We see that the training accuracy for our model is around 95% or higher, while the validation accuracy fluctuates considerably from epoch to epoch.</p> <hr /> <h2 
id="predict-and-evaluate-results">Predict and evaluate results</h2> <p>Let's evaluate the model on our test data!</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">test_ds</span><span class="p">,</span> <span class="n">return_dict</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>4/4 [==============================] - 3s 708ms/step - loss: 0.9718 - binary_accuracy: 0.7901 - precision: 0.7524 - recall: 0.9897
{&#39;binary_accuracy&#39;: 0.7900640964508057, &#39;loss&#39;: 0.9717951416969299, &#39;precision&#39;: 0.752436637878418, &#39;recall&#39;: 0.9897436499595642}
</code></pre></div> </div> <p>Our accuracy on the test data (about 79%) is lower than the accuracy on our validation set. This may indicate overfitting.</p> <p>Our recall (about 0.99) is greater than our precision (about 0.75), indicating that almost all pneumonia images are correctly identified but some normal images are misclassified as pneumonia. 
We should aim to increase our precision.</p> <div class="codehilite"><pre><span></span><code><span class="k">for</span> <span class="n">image</span><span class="p">,</span> <span class="n">label</span> <span class="ow">in</span> <span class="n">test_ds</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="mi">1</span><span class="p">):</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">/</span> <span class="mf">255.0</span><span class="p">)</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="n">CLASS_NAMES</span><span class="p">[</span><span class="n">label</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">numpy</span><span class="p">()])</span>

<span class="n">prediction</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">test_ds</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="mi">1</span><span class="p">))[</span><span class="mi">0</span><span class="p">]</span>
<span class="n">scores</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1</span> <span class="o">-</span> <span class="n">prediction</span><span class="p">,</span> <span class="n">prediction</span><span class="p">]</span>

<span class="k">for</span> <span class="n">score</span><span class="p">,</span> <span class="n">name</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="n">scores</span><span class="p">,</span> <span class="n">CLASS_NAMES</span><span class="p">):</span>
    <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;This image is </span><span class="si">%.2f</span><span class="s2"> percent </span><span class="si">%s</span><span class="s2">&quot;</span> <span class="o">%</span> <span class="p">((</span><span class="mi">100</span> <span class="o">*</span> <span class="n">score</span><span class="p">),</span> <span class="n">name</span><span class="p">))</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:3: DeprecationWarning: In future, it will be an error for &#39;np.bool_&#39; scalars to be interpreted as an index
  This is separate from the ipykernel package so we can avoid doing imports until
This image is 47.19 percent NORMAL
This image is 52.81 percent PNEUMONIA
</code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/xray_classification_with_tpus/xray_classification_with_tpus_46_2.png" /></p> </div>
<div class='k-outline'>
<div class='k-outline-depth-1'> <a href='#pneumonia-classification-on-tpu'>Pneumonia Classification on TPU</a> </div>
<div class='k-outline-depth-2'> ◆ <a href='#introduction--setup'>Introduction + Set-up</a> </div>
<div class='k-outline-depth-2'> ◆ <a href='#load-the-data'>Load the data</a> </div>
<div class='k-outline-depth-2'> ◆ <a href='#visualize-the-dataset'>Visualize the dataset</a> </div>
<div class='k-outline-depth-2'> ◆ <a href='#build-the-cnn'>Build the CNN</a> </div>
<div class='k-outline-depth-2'> ◆ <a href='#correct-for-data-imbalance'>Correct for data imbalance</a> </div>
<div class='k-outline-depth-2'> ◆ <a href='#train-the-model'>Train the model</a> </div>
<div class='k-outline-depth-3'> <a href='#defining-callbacks'>Defining callbacks</a> </div>
<div class='k-outline-depth-3'> <a href='#fit-the-model'>Fit the model</a> </div>
<div class='k-outline-depth-2'> ◆ <a href='#visualizing-model-performance'>Visualizing model performance</a> </div>
<div class='k-outline-depth-2'> ◆ <a 
href='#predict-and-evaluate-results'>Predict and evaluate results</a> </div> </div> </div> </div> </div> </body> </html>
