CINXE.COM
CutMix data augmentation for image classification
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/cutmix/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: CutMix data augmentation for image classification"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: CutMix data augmentation for image classification"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>CutMix data augmentation for image classification</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a class="nav-sublink2" href="/examples/vision/mobilevit/">A mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using Composable Fully-Convolutional Networks</a> <a class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2 active" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved Robustness</a> <a class="nav-sublink2" href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/bit/">Image Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a class="nav-sublink2" href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/vision/'>Computer Vision</a> / CutMix data augmentation for image classification </div> <div class='k-content'> <h1 id="cutmix-data-augmentation-for-image-classification">CutMix data augmentation for image classification</h1> <p><strong>Author:</strong> <a href="https://twitter.com/sayannath2350">Sayan Nath</a><br> <strong>Date created:</strong> 2021/06/08<br> <strong>Last modified:</strong> 2023/11/14<br> <strong>Description:</strong> Data augmentation with CutMix for image classification on CIFAR-10.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/cutmix.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/cutmix.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p><em>CutMix</em> is a data augmentation technique that addresses the issue of information loss and inefficiency present in regional dropout strategies. Instead of removing pixels and filling them with black or grey pixels or Gaussian noise, you replace the removed regions with a patch from another image, while the ground truth labels are mixed proportionally to the number of pixels of combined images. CutMix was proposed in <a href="https://arxiv.org/abs/1905.04899">CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features</a> (Yun et al., 2019)</p> <p>It's implemented via the following formulas:</p> <p><img src="https://i.imgur.com/cGvd13V.png" width="200"/></p> <p>where <code>M</code> is the binary mask which indicates the cutout and the fill-in regions from the two randomly drawn images and <code>λ</code> (in <code>[0, 1]</code>) is drawn from a <a href="https://en.wikipedia.org/wiki/Beta_distribution"><code>Beta(α, α)</code> distribution</a></p> <p>The coordinates of bounding boxes are:</p> <p><img src="https://i.imgur.com/eNisep4.png" width="150"/></p> <p>which indicates the cutout and fill-in regions in case of the images. The bounding box sampling is represented by:</p> <p><img src="https://i.imgur.com/Snph9aj.png" width="200"/></p> <p>where <code>rx, ry</code> are randomly drawn from a uniform distribution with upper bound.</p> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> <span class="c1"># TF imports related to tf.data preprocessing</span> <span class="kn">from</span> <span class="nn">tensorflow</span> <span class="kn">import</span> <span class="n">clip_by_value</span> <span class="kn">from</span> <span class="nn">tensorflow</span> <span class="kn">import</span> <span class="n">data</span> <span class="k">as</span> <span class="n">tf_data</span> <span class="kn">from</span> <span class="nn">tensorflow</span> <span class="kn">import</span> <span class="n">image</span> <span class="k">as</span> <span class="n">tf_image</span> <span class="kn">from</span> <span class="nn">tensorflow</span> <span class="kn">import</span> <span class="n">random</span> <span class="k">as</span> <span class="n">tf_random</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">set_random_seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="load-the-cifar10-dataset">Load the CIFAR-10 dataset</h2> <p>In this example, we will use the <a href="https://www.cs.toronto.edu/~kriz/cifar.html">CIFAR-10 image classification dataset</a>.</p> <div class="codehilite"><pre><span></span><code><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">cifar10</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span> <span class="n">y_train</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">to_categorical</span><span class="p">(</span><span class="n">y_train</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> <span class="n">y_test</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">to_categorical</span><span class="p">(</span><span class="n">y_test</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">x_train</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">y_train</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">x_test</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">y_test</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="n">class_names</span> <span class="o">=</span> <span class="p">[</span> <span class="s2">"Airplane"</span><span class="p">,</span> <span class="s2">"Automobile"</span><span class="p">,</span> <span class="s2">"Bird"</span><span class="p">,</span> <span class="s2">"Cat"</span><span class="p">,</span> <span class="s2">"Deer"</span><span class="p">,</span> <span class="s2">"Dog"</span><span class="p">,</span> <span class="s2">"Frog"</span><span class="p">,</span> <span class="s2">"Horse"</span><span class="p">,</span> <span class="s2">"Ship"</span><span class="p">,</span> <span class="s2">"Truck"</span><span class="p">,</span> <span class="p">]</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>(50000, 32, 32, 3) (50000, 10) (10000, 32, 32, 3) (10000, 10) </code></pre></div> </div> <hr /> <h2 id="define-hyperparameters">Define hyperparameters</h2> <div class="codehilite"><pre><span></span><code><span class="n">AUTO</span> <span class="o">=</span> <span class="n">tf_data</span><span class="o">.</span><span class="n">AUTOTUNE</span> <span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">32</span> <span class="n">IMG_SIZE</span> <span class="o">=</span> <span class="mi">32</span> </code></pre></div> <hr /> <h2 id="define-the-image-preprocessing-function">Define the image preprocessing function</h2> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">preprocess_image</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">label</span><span class="p">):</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf_image</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="p">(</span><span class="n">IMG_SIZE</span><span class="p">,</span> <span class="n">IMG_SIZE</span><span class="p">))</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf_image</span><span class="o">.</span><span class="n">convert_image_dtype</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="s2">"float32"</span><span class="p">)</span> <span class="o">/</span> <span class="mf">255.0</span> <span class="n">label</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">label</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span><span class="p">)</span> <span class="k">return</span> <span class="n">image</span><span class="p">,</span> <span class="n">label</span> </code></pre></div> <hr /> <h2 id="convert-the-data-into-tensorflow-dataset-objects">Convert the data into TensorFlow <code>Dataset</code> objects</h2> <div class="codehilite"><pre><span></span><code><span class="n">train_ds_one</span> <span class="o">=</span> <span class="p">(</span> <span class="n">tf_data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">))</span> <span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="mi">1024</span><span class="p">)</span> <span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">preprocess_image</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTO</span><span class="p">)</span> <span class="p">)</span> <span class="n">train_ds_two</span> <span class="o">=</span> <span class="p">(</span> <span class="n">tf_data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">))</span> <span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="mi">1024</span><span class="p">)</span> <span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">preprocess_image</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTO</span><span class="p">)</span> <span class="p">)</span> <span class="n">train_ds_simple</span> <span class="o">=</span> <span class="n">tf_data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">))</span> <span class="n">test_ds</span> <span class="o">=</span> <span class="n">tf_data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">))</span> <span class="n">train_ds_simple</span> <span class="o">=</span> <span class="p">(</span> <span class="n">train_ds_simple</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">preprocess_image</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTO</span><span class="p">)</span> <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span> <span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">AUTO</span><span class="p">)</span> <span class="p">)</span> <span class="c1"># Combine two shuffled datasets from the same training data.</span> <span class="n">train_ds</span> <span class="o">=</span> <span class="n">tf_data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">zip</span><span class="p">((</span><span class="n">train_ds_one</span><span class="p">,</span> <span class="n">train_ds_two</span><span class="p">))</span> <span class="n">test_ds</span> <span class="o">=</span> <span class="p">(</span> <span class="n">test_ds</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">preprocess_image</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTO</span><span class="p">)</span> <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span> <span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">AUTO</span><span class="p">)</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="define-the-cutmix-data-augmentation-function">Define the CutMix data augmentation function</h2> <p>The CutMix function takes two <code>image</code> and <code>label</code> pairs to perform the augmentation. It samples <code>λ(l)</code> from the <a href="https://en.wikipedia.org/wiki/Beta_distribution">Beta distribution</a> and returns a bounding box from <code>get_box</code> function. We then crop the second image (<code>image2</code>) and pad this image in the final padded image at the same location.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">sample_beta_distribution</span><span class="p">(</span><span class="n">size</span><span class="p">,</span> <span class="n">concentration_0</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">concentration_1</span><span class="o">=</span><span class="mf">0.2</span><span class="p">):</span> <span class="n">gamma_1_sample</span> <span class="o">=</span> <span class="n">tf_random</span><span class="o">.</span><span class="n">gamma</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="n">size</span><span class="p">],</span> <span class="n">alpha</span><span class="o">=</span><span class="n">concentration_1</span><span class="p">)</span> <span class="n">gamma_2_sample</span> <span class="o">=</span> <span class="n">tf_random</span><span class="o">.</span><span class="n">gamma</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="n">size</span><span class="p">],</span> <span class="n">alpha</span><span class="o">=</span><span class="n">concentration_0</span><span class="p">)</span> <span class="k">return</span> <span class="n">gamma_1_sample</span> <span class="o">/</span> <span class="p">(</span><span class="n">gamma_1_sample</span> <span class="o">+</span> <span class="n">gamma_2_sample</span><span class="p">)</span> <span class="k">def</span> <span class="nf">get_box</span><span class="p">(</span><span class="n">lambda_value</span><span class="p">):</span> <span class="n">cut_rat</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="mf">1.0</span> <span class="o">-</span> <span class="n">lambda_value</span><span class="p">)</span> <span class="n">cut_w</span> <span class="o">=</span> <span class="n">IMG_SIZE</span> <span class="o">*</span> <span class="n">cut_rat</span> <span class="c1"># rw</span> <span class="n">cut_w</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">cut_w</span><span class="p">,</span> <span class="s2">"int32"</span><span class="p">)</span> <span class="n">cut_h</span> <span class="o">=</span> <span class="n">IMG_SIZE</span> <span class="o">*</span> <span class="n">cut_rat</span> <span class="c1"># rh</span> <span class="n">cut_h</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">cut_h</span><span class="p">,</span> <span class="s2">"int32"</span><span class="p">)</span> <span class="n">cut_x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">((</span><span class="mi">1</span><span class="p">,),</span> <span class="n">minval</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">maxval</span><span class="o">=</span><span class="n">IMG_SIZE</span><span class="p">)</span> <span class="c1"># rx</span> <span class="n">cut_x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">cut_x</span><span class="p">,</span> <span class="s2">"int32"</span><span class="p">)</span> <span class="n">cut_y</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">((</span><span class="mi">1</span><span class="p">,),</span> <span class="n">minval</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">maxval</span><span class="o">=</span><span class="n">IMG_SIZE</span><span class="p">)</span> <span class="c1"># ry</span> <span class="n">cut_y</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">cut_y</span><span class="p">,</span> <span class="s2">"int32"</span><span class="p">)</span> <span class="n">boundaryx1</span> <span class="o">=</span> <span class="n">clip_by_value</span><span class="p">(</span><span class="n">cut_x</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="n">cut_w</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">IMG_SIZE</span><span class="p">)</span> <span class="n">boundaryy1</span> <span class="o">=</span> <span class="n">clip_by_value</span><span class="p">(</span><span class="n">cut_y</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="n">cut_h</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">IMG_SIZE</span><span class="p">)</span> <span class="n">bbx2</span> <span class="o">=</span> <span class="n">clip_by_value</span><span class="p">(</span><span class="n">cut_x</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">cut_w</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">IMG_SIZE</span><span class="p">)</span> <span class="n">bby2</span> <span class="o">=</span> <span class="n">clip_by_value</span><span class="p">(</span><span class="n">cut_y</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">cut_h</span> <span class="o">//</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="n">IMG_SIZE</span><span class="p">)</span> <span class="n">target_h</span> <span class="o">=</span> <span class="n">bby2</span> <span class="o">-</span> <span class="n">boundaryy1</span> <span class="k">if</span> <span class="n">target_h</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> <span class="n">target_h</span> <span class="o">+=</span> <span class="mi">1</span> <span class="n">target_w</span> <span class="o">=</span> <span class="n">bbx2</span> <span class="o">-</span> <span class="n">boundaryx1</span> <span class="k">if</span> <span class="n">target_w</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> <span class="n">target_w</span> <span class="o">+=</span> <span class="mi">1</span> <span class="k">return</span> <span class="n">boundaryx1</span><span class="p">,</span> <span class="n">boundaryy1</span><span class="p">,</span> <span class="n">target_h</span><span class="p">,</span> <span class="n">target_w</span> <span class="k">def</span> <span class="nf">cutmix</span><span class="p">(</span><span class="n">train_ds_one</span><span class="p">,</span> <span class="n">train_ds_two</span><span class="p">):</span> <span class="p">(</span><span class="n">image1</span><span class="p">,</span> <span class="n">label1</span><span class="p">),</span> <span class="p">(</span><span class="n">image2</span><span class="p">,</span> <span class="n">label2</span><span class="p">)</span> <span class="o">=</span> <span class="n">train_ds_one</span><span class="p">,</span> <span class="n">train_ds_two</span> <span class="n">alpha</span> <span class="o">=</span> <span class="p">[</span><span class="mf">0.25</span><span class="p">]</span> <span class="n">beta</span> <span class="o">=</span> <span class="p">[</span><span class="mf">0.25</span><span class="p">]</span> <span class="c1"># Get a sample from the Beta distribution</span> <span class="n">lambda_value</span> <span class="o">=</span> <span class="n">sample_beta_distribution</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">alpha</span><span class="p">,</span> <span class="n">beta</span><span class="p">)</span> <span class="c1"># Define Lambda</span> <span class="n">lambda_value</span> <span class="o">=</span> <span class="n">lambda_value</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span> <span class="c1"># Get the bounding box offsets, heights and widths</span> <span class="n">boundaryx1</span><span class="p">,</span> <span class="n">boundaryy1</span><span class="p">,</span> <span class="n">target_h</span><span class="p">,</span> <span class="n">target_w</span> <span class="o">=</span> <span class="n">get_box</span><span class="p">(</span><span class="n">lambda_value</span><span class="p">)</span> <span class="c1"># Get a patch from the second image (`image2`)</span> <span class="n">crop2</span> <span class="o">=</span> <span class="n">tf_image</span><span class="o">.</span><span class="n">crop_to_bounding_box</span><span class="p">(</span> <span class="n">image2</span><span class="p">,</span> <span class="n">boundaryy1</span><span class="p">,</span> <span class="n">boundaryx1</span><span class="p">,</span> <span class="n">target_h</span><span class="p">,</span> <span class="n">target_w</span> <span class="p">)</span> <span class="c1"># Pad the `image2` patch (`crop2`) with the same offset</span> <span class="n">image2</span> <span class="o">=</span> <span class="n">tf_image</span><span class="o">.</span><span class="n">pad_to_bounding_box</span><span class="p">(</span> <span class="n">crop2</span><span class="p">,</span> <span class="n">boundaryy1</span><span class="p">,</span> <span class="n">boundaryx1</span><span class="p">,</span> <span class="n">IMG_SIZE</span><span class="p">,</span> <span class="n">IMG_SIZE</span> <span class="p">)</span> <span class="c1"># Get a patch from the first image (`image1`)</span> <span class="n">crop1</span> <span class="o">=</span> <span class="n">tf_image</span><span class="o">.</span><span class="n">crop_to_bounding_box</span><span class="p">(</span> <span class="n">image1</span><span class="p">,</span> <span class="n">boundaryy1</span><span class="p">,</span> <span class="n">boundaryx1</span><span class="p">,</span> <span class="n">target_h</span><span class="p">,</span> <span class="n">target_w</span> <span class="p">)</span> <span class="c1"># Pad the `image1` patch (`crop1`) with the same offset</span> <span class="n">img1</span> <span class="o">=</span> <span class="n">tf_image</span><span class="o">.</span><span class="n">pad_to_bounding_box</span><span class="p">(</span> <span class="n">crop1</span><span class="p">,</span> <span class="n">boundaryy1</span><span class="p">,</span> <span class="n">boundaryx1</span><span class="p">,</span> <span class="n">IMG_SIZE</span><span class="p">,</span> <span class="n">IMG_SIZE</span> <span class="p">)</span> <span class="c1"># Modify the first image by subtracting the patch from `image1`</span> <span class="c1"># (before applying the `image2` patch)</span> <span class="n">image1</span> <span class="o">=</span> <span class="n">image1</span> <span class="o">-</span> <span class="n">img1</span> <span class="c1"># Add the modified `image1` and `image2` together to get the CutMix image</span> <span class="n">image</span> <span class="o">=</span> <span class="n">image1</span> <span class="o">+</span> <span class="n">image2</span> <span class="c1"># Adjust Lambda in accordance to the pixel ration</span> <span class="n">lambda_value</span> <span class="o">=</span> <span class="mi">1</span> <span class="o">-</span> <span class="p">(</span><span class="n">target_w</span> <span class="o">*</span> <span class="n">target_h</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">IMG_SIZE</span> <span class="o">*</span> <span class="n">IMG_SIZE</span><span class="p">)</span> <span class="n">lambda_value</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">lambda_value</span><span class="p">,</span> <span class="s2">"float32"</span><span class="p">)</span> <span class="c1"># Combine the labels of both images</span> <span class="n">label</span> <span class="o">=</span> <span class="n">lambda_value</span> <span class="o">*</span> <span class="n">label1</span> <span class="o">+</span> <span class="p">(</span><span class="mi">1</span> <span class="o">-</span> <span class="n">lambda_value</span><span class="p">)</span> <span class="o">*</span> <span class="n">label2</span> <span class="k">return</span> <span class="n">image</span><span class="p">,</span> <span class="n">label</span> </code></pre></div> <p><strong>Note</strong>: we are combining two images to create a single one.</p> <hr /> <h2 id="visualize-the-new-dataset-after-applying-the-cutmix-augmentation">Visualize the new dataset after applying the CutMix augmentation</h2> <div class="codehilite"><pre><span></span><code><span class="c1"># Create the new dataset using our `cutmix` utility</span> <span class="n">train_ds_cmu</span> <span class="o">=</span> <span class="p">(</span> <span class="n">train_ds</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="mi">1024</span><span class="p">)</span> <span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">cutmix</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTO</span><span class="p">)</span> <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span> <span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">AUTO</span><span class="p">)</span> <span class="p">)</span> <span class="c1"># Let's preview 9 samples from the dataset</span> <span class="n">image_batch</span><span class="p">,</span> <span class="n">label_batch</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">train_ds_cmu</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">9</span><span class="p">):</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="n">class_names</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">label_batch</span><span class="p">[</span><span class="n">i</span><span class="p">])])</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image_batch</span><span class="p">[</span><span class="n">i</span><span class="p">])</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/cutmix/cutmix_16_0.png" /></p> <hr /> <h2 id="define-a-resnet20-model">Define a ResNet-20 model</h2> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">resnet_layer</span><span class="p">(</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">num_filters</span><span class="o">=</span><span class="mi">16</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">,</span> <span class="n">batch_normalization</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">conv_first</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">):</span> <span class="n">conv</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span> <span class="n">num_filters</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="n">kernel_size</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="n">strides</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">,</span> <span class="n">kernel_initializer</span><span class="o">=</span><span class="s2">"he_normal"</span><span class="p">,</span> <span class="n">kernel_regularizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">regularizers</span><span class="o">.</span><span class="n">L2</span><span class="p">(</span><span class="mf">1e-4</span><span class="p">),</span> <span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">inputs</span> <span class="k">if</span> <span class="n">conv_first</span><span class="p">:</span> <span class="n">x</span> <span class="o">=</span> <span class="n">conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">if</span> <span class="n">batch_normalization</span><span class="p">:</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="k">if</span> <span class="n">activation</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">activation</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="k">if</span> <span class="n">batch_normalization</span><span class="p">:</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="k">if</span> <span class="n">activation</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="n">activation</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">conv</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> <span class="k">def</span> <span class="nf">resnet_v20</span><span class="p">(</span><span class="n">input_shape</span><span class="p">,</span> <span class="n">depth</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="mi">10</span><span class="p">):</span> <span class="k">if</span> <span class="p">(</span><span class="n">depth</span> <span class="o">-</span> <span class="mi">2</span><span class="p">)</span> <span class="o">%</span> <span class="mi">6</span> <span class="o">!=</span> <span class="mi">0</span><span class="p">:</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">"depth should be 6n+2 (eg 20, 32, 44 in [a])"</span><span class="p">)</span> <span class="c1"># Start model definition.</span> <span class="n">num_filters</span> <span class="o">=</span> <span class="mi">16</span> <span class="n">num_res_blocks</span> <span class="o">=</span> <span class="nb">int</span><span class="p">((</span><span class="n">depth</span> <span class="o">-</span> <span class="mi">2</span><span class="p">)</span> <span class="o">/</span> <span class="mi">6</span><span class="p">)</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">input_shape</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">resnet_layer</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">)</span> <span class="c1"># Instantiate the stack of residual units</span> <span class="k">for</span> <span class="n">stack</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">3</span><span class="p">):</span> <span class="k">for</span> <span class="n">res_block</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_res_blocks</span><span class="p">):</span> <span class="n">strides</span> <span class="o">=</span> <span class="mi">1</span> <span class="k">if</span> <span class="n">stack</span> <span class="o">></span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">res_block</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> <span class="c1"># first layer but not first stack</span> <span class="n">strides</span> <span class="o">=</span> <span class="mi">2</span> <span class="c1"># downsample</span> <span class="n">y</span> <span class="o">=</span> <span class="n">resnet_layer</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">num_filters</span><span class="o">=</span><span class="n">num_filters</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="n">strides</span><span class="p">)</span> <span class="n">y</span> <span class="o">=</span> <span class="n">resnet_layer</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">y</span><span class="p">,</span> <span class="n">num_filters</span><span class="o">=</span><span class="n">num_filters</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> <span class="k">if</span> <span class="n">stack</span> <span class="o">></span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">res_block</span> <span class="o">==</span> <span class="mi">0</span><span class="p">:</span> <span class="c1"># first layer but not first stack</span> <span class="c1"># linear projection residual shortcut connection to match</span> <span class="c1"># changed dims</span> <span class="n">x</span> <span class="o">=</span> <span class="n">resnet_layer</span><span class="p">(</span> <span class="n">inputs</span><span class="o">=</span><span class="n">x</span><span class="p">,</span> <span class="n">num_filters</span><span class="o">=</span><span class="n">num_filters</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">strides</span><span class="o">=</span><span class="n">strides</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">batch_normalization</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">add</span><span class="p">([</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">])</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Activation</span><span class="p">(</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">num_filters</span> <span class="o">*=</span> <span class="mi">2</span> <span class="c1"># Add classifier on top.</span> <span class="c1"># v1 does not use BN after last shortcut connection-ReLU</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">AveragePooling2D</span><span class="p">(</span><span class="n">pool_size</span><span class="o">=</span><span class="mi">8</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">y</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Flatten</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span> <span class="n">num_classes</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"softmax"</span><span class="p">,</span> <span class="n">kernel_initializer</span><span class="o">=</span><span class="s2">"he_normal"</span> <span class="p">)(</span><span class="n">y</span><span class="p">)</span> <span class="c1"># Instantiate model.</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">outputs</span><span class="p">)</span> <span class="k">return</span> <span class="n">model</span> <span class="k">def</span> <span class="nf">training_model</span><span class="p">():</span> <span class="k">return</span> <span class="n">resnet_v20</span><span class="p">((</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="mi">20</span><span class="p">)</span> <span class="n">initial_model</span> <span class="o">=</span> <span class="n">training_model</span><span class="p">()</span> <span class="n">initial_model</span><span class="o">.</span><span class="n">save_weights</span><span class="p">(</span><span class="s2">"initial_weights.weights.h5"</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="train-the-model-with-the-dataset-augmented-by-cutmix">Train the model with the dataset augmented by CutMix</h2> <div class="codehilite"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="n">training_model</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">load_weights</span><span class="p">(</span><span class="s2">"initial_weights.weights.h5"</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">loss</span><span class="o">=</span><span class="s2">"categorical_crossentropy"</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="s2">"adam"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"accuracy"</span><span class="p">])</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_ds_cmu</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">test_ds</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">15</span><span class="p">)</span> <span class="n">test_loss</span><span class="p">,</span> <span class="n">test_accuracy</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">test_ds</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Test accuracy: </span><span class="si">{:.2f}</span><span class="s2">%"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">test_accuracy</span> <span class="o">*</span> <span class="mi">100</span><span class="p">))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/15 10/1563 [37m━━━━━━━━━━━━━━━━━━━━ 19s 13ms/step - accuracy: 0.0795 - loss: 5.3035 WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1699988196.560261 362411 device_compiler.h:187] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 64s 27ms/step - accuracy: 0.3148 - loss: 2.1918 - val_accuracy: 0.4067 - val_loss: 1.8339 Epoch 2/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 27s 17ms/step - accuracy: 0.4295 - loss: 1.9021 - val_accuracy: 0.5516 - val_loss: 1.4744 Epoch 3/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 28s 18ms/step - accuracy: 0.4883 - loss: 1.8076 - val_accuracy: 0.5305 - val_loss: 1.5067 Epoch 4/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 27s 17ms/step - accuracy: 0.5243 - loss: 1.7342 - val_accuracy: 0.6303 - val_loss: 1.2822 Epoch 5/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 27s 17ms/step - accuracy: 0.5574 - loss: 1.6614 - val_accuracy: 0.5370 - val_loss: 1.5912 Epoch 6/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 27s 17ms/step - accuracy: 0.5832 - loss: 1.6167 - val_accuracy: 0.6254 - val_loss: 1.3116 Epoch 7/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 26s 17ms/step - accuracy: 0.6045 - loss: 1.5738 - val_accuracy: 0.6101 - val_loss: 1.3408 Epoch 8/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 28s 18ms/step - accuracy: 0.6170 - loss: 1.5493 - val_accuracy: 0.6209 - val_loss: 1.2923 Epoch 9/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 29s 18ms/step - accuracy: 0.6292 - loss: 1.5299 - val_accuracy: 0.6290 - val_loss: 1.2813 Epoch 10/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 28s 18ms/step - accuracy: 0.6394 - loss: 1.5110 - val_accuracy: 0.7234 - val_loss: 1.0608 Epoch 11/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 26s 17ms/step - accuracy: 0.6467 - loss: 1.4915 - val_accuracy: 0.7498 - val_loss: 0.9854 Epoch 12/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 28s 18ms/step - accuracy: 0.6559 - loss: 1.4785 - val_accuracy: 0.6481 - val_loss: 1.2410 Epoch 13/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 26s 17ms/step - accuracy: 0.6596 - loss: 1.4656 - val_accuracy: 0.7551 - val_loss: 0.9784 Epoch 14/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 27s 17ms/step - accuracy: 0.6577 - loss: 1.4637 - val_accuracy: 0.6822 - val_loss: 1.1703 Epoch 15/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 26s 17ms/step - accuracy: 0.6702 - loss: 1.4445 - val_accuracy: 0.7108 - val_loss: 1.0805 313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - accuracy: 0.7140 - loss: 1.0766 Test accuracy: 71.08% </code></pre></div> </div> <hr /> <h2 id="train-the-model-using-the-original-nonaugmented-dataset">Train the model using the original non-augmented dataset</h2> <div class="codehilite"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="n">training_model</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">load_weights</span><span class="p">(</span><span class="s2">"initial_weights.weights.h5"</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">loss</span><span class="o">=</span><span class="s2">"categorical_crossentropy"</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="s2">"adam"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"accuracy"</span><span class="p">])</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_ds_simple</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">test_ds</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">15</span><span class="p">)</span> <span class="n">test_loss</span><span class="p">,</span> <span class="n">test_accuracy</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">test_ds</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Test accuracy: </span><span class="si">{:.2f}</span><span class="s2">%"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">test_accuracy</span> <span class="o">*</span> <span class="mi">100</span><span class="p">))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 41s 15ms/step - accuracy: 0.3943 - loss: 1.8736 - val_accuracy: 0.5359 - val_loss: 1.4376 Epoch 2/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 11s 7ms/step - accuracy: 0.6160 - loss: 1.2407 - val_accuracy: 0.5887 - val_loss: 1.4254 Epoch 3/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 11s 7ms/step - accuracy: 0.6927 - loss: 1.0448 - val_accuracy: 0.6102 - val_loss: 1.4850 Epoch 4/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 12s 7ms/step - accuracy: 0.7411 - loss: 0.9222 - val_accuracy: 0.6262 - val_loss: 1.3898 Epoch 5/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 13s 8ms/step - accuracy: 0.7711 - loss: 0.8439 - val_accuracy: 0.6283 - val_loss: 1.3425 Epoch 6/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 12s 8ms/step - accuracy: 0.7983 - loss: 0.7886 - val_accuracy: 0.2460 - val_loss: 5.6869 Epoch 7/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 11s 7ms/step - accuracy: 0.8168 - loss: 0.7490 - val_accuracy: 0.1954 - val_loss: 21.7670 Epoch 8/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 11s 7ms/step - accuracy: 0.8113 - loss: 0.7779 - val_accuracy: 0.1027 - val_loss: 36.3144 Epoch 9/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 11s 7ms/step - accuracy: 0.6592 - loss: 1.4179 - val_accuracy: 0.1025 - val_loss: 40.0770 Epoch 10/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 12s 8ms/step - accuracy: 0.5611 - loss: 1.9856 - val_accuracy: 0.1699 - val_loss: 40.6308 Epoch 11/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 13s 8ms/step - accuracy: 0.6076 - loss: 1.7795 - val_accuracy: 0.1003 - val_loss: 63.4775 Epoch 12/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 12s 7ms/step - accuracy: 0.6175 - loss: 1.8077 - val_accuracy: 0.1099 - val_loss: 21.9148 Epoch 13/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 12s 7ms/step - accuracy: 0.6468 - loss: 1.6702 - val_accuracy: 0.1576 - val_loss: 72.7290 Epoch 14/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 12s 7ms/step - accuracy: 0.6437 - loss: 1.7858 - val_accuracy: 0.1000 - val_loss: 64.9249 Epoch 15/15 1563/1563 ━━━━━━━━━━━━━━━━━━━━ 13s 8ms/step - accuracy: 0.6587 - loss: 1.7587 - val_accuracy: 0.1000 - val_loss: 138.8463 313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - accuracy: 0.0988 - loss: 139.3117 Test accuracy: 10.00% </code></pre></div> </div> <hr /> <h2 id="notes">Notes</h2> <p>In this example, we trained our model for 15 epochs. In our experiment, the model with CutMix achieves a better accuracy on the CIFAR-10 dataset (77.34% in our experiment) compared to the model that doesn't use the augmentation (66.90%). You may notice it takes less time to train the model with the CutMix augmentation.</p> <p>You can experiment further with the CutMix technique by following the <a href="https://arxiv.org/abs/1905.04899">original paper</a>.</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#cutmix-data-augmentation-for-image-classification'>CutMix data augmentation for image classification</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#load-the-cifar10-dataset'>Load the CIFAR-10 dataset</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#define-hyperparameters'>Define hyperparameters</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#define-the-image-preprocessing-function'>Define the image preprocessing function</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#convert-the-data-into-tensorflow-dataset-objects'>Convert the data into TensorFlow <code>Dataset</code> objects</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#define-the-cutmix-data-augmentation-function'>Define the CutMix data augmentation function</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#visualize-the-new-dataset-after-applying-the-cutmix-augmentation'>Visualize the new dataset after applying the CutMix augmentation</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#define-a-resnet20-model'>Define a ResNet-20 model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-the-model-with-the-dataset-augmented-by-cutmix'>Train the model with the dataset augmented by CutMix</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-the-model-using-the-original-nonaugmented-dataset'>Train the model using the original non-augmented dataset</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#notes'>Notes</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>