CINXE.COM
RandAugment for Image Classification for Improved Robustness
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/randaugment/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: RandAugment for Image Classification for Improved Robustness"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: RandAugment for Image Classification for Improved Robustness"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>RandAugment for Image Classification for Improved Robustness</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a class="nav-sublink2" href="/examples/vision/mobilevit/">A mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using Composable Fully-Convolutional Networks</a> <a class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2 active" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved Robustness</a> <a class="nav-sublink2" href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/bit/">Image Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a class="nav-sublink2" href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/vision/'>Computer Vision</a> / RandAugment for Image Classification for Improved Robustness </div> <div class='k-content'> <h1 id="randaugment-for-image-classification-for-improved-robustness">RandAugment for Image Classification for Improved Robustness</h1> <p><strong>Authors:</strong> <a href="https://twitter.com/RisingSayak">Sayak Paul</a><a href="https://github.com/sachinprasadhs">Sachin Prasad</a><br> <strong>Date created:</strong> 2021/03/13<br> <strong>Last modified:</strong> 2023/12/12<br> <strong>Description:</strong> RandAugment for training an image classification model with improved robustness.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/randaugment.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/randaugment.py"><strong>GitHub source</strong></a></p> <p>Data augmentation is a very useful technique that can help to improve the translational invariance of convolutional neural networks (CNN). RandAugment is a stochastic data augmentation routine for vision data and was proposed in <a href="https://arxiv.org/abs/1909.13719">RandAugment: Practical automated data augmentation with a reduced search space</a>. It is composed of strong augmentation transforms like color jitters, Gaussian blurs, saturations, etc. along with more traditional augmentation transforms such as random crops.</p> <p>These parameters are tuned for a given dataset and a network architecture. The authors of RandAugment also provide pseudocode of RandAugment in the original paper (Figure 2).</p> <p>Recently, it has been a key component of works like <a href="https://arxiv.org/abs/1911.04252">Noisy Student Training</a> and <a href="https://arxiv.org/abs/1904.12848">Unsupervised Data Augmentation for Consistency Training</a>. It has been also central to the success of <a href="https://arxiv.org/abs/1905.11946">EfficientNets</a>.</p> <div class="codehilite"><pre><span></span><code><span class="n">pip</span> <span class="n">install</span> <span class="n">keras</span><span class="o">-</span><span class="n">cv</span> </code></pre></div> <hr /> <h2 id="imports-amp-setup">Imports & setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"tensorflow"</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="nn">keras_cv</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">ops</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> <span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span> <span class="kn">import</span> <span class="nn">tensorflow_datasets</span> <span class="k">as</span> <span class="nn">tfds</span> <span class="n">tfds</span><span class="o">.</span><span class="n">disable_progress_bar</span><span class="p">()</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">set_random_seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="load-the-cifar10-dataset">Load the CIFAR10 dataset</h2> <p>For this example, we will be using the <a href="https://www.cs.toronto.edu/~kriz/cifar.html">CIFAR10 dataset</a>.</p> <div class="codehilite"><pre><span></span><code><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">cifar10</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Total training examples: </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">x_train</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Total test examples: </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">x_test</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Total training examples: 50000 Total test examples: 10000 </code></pre></div> </div> <hr /> <h2 id="define-hyperparameters">Define hyperparameters</h2> <div class="codehilite"><pre><span></span><code><span class="n">AUTO</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span> <span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">128</span> <span class="n">EPOCHS</span> <span class="o">=</span> <span class="mi">1</span> <span class="n">IMAGE_SIZE</span> <span class="o">=</span> <span class="mi">72</span> </code></pre></div> <hr /> <h2 id="initialize-randaugment-object">Initialize <code>RandAugment</code> object</h2> <p>Now, we will initialize a <code>RandAugment</code> object from the <code>imgaug.augmenters</code> module with the parameters suggested by the RandAugment authors.</p> <div class="codehilite"><pre><span></span><code><span class="n">rand_augment</span> <span class="o">=</span> <span class="n">keras_cv</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">RandAugment</span><span class="p">(</span> <span class="n">value_range</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">255</span><span class="p">),</span> <span class="n">augmentations_per_image</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">magnitude</span><span class="o">=</span><span class="mf">0.8</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="create-tensorflow-dataset-objects">Create TensorFlow <code>Dataset</code> objects</h2> <div class="codehilite"><pre><span></span><code><span class="n">train_ds_rand</span> <span class="o">=</span> <span class="p">(</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">))</span> <span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">BATCH_SIZE</span> <span class="o">*</span> <span class="mi">100</span><span class="p">)</span> <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span> <span class="o">.</span><span class="n">map</span><span class="p">(</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="p">(</span><span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="n">IMAGE_SIZE</span><span class="p">)),</span> <span class="n">y</span><span class="p">),</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTO</span><span class="p">,</span> <span class="p">)</span> <span class="o">.</span><span class="n">map</span><span class="p">(</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="p">(</span><span class="n">rand_augment</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">tf</span><span class="o">.</span><span class="n">uint8</span><span class="p">)),</span> <span class="n">y</span><span class="p">),</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTO</span><span class="p">,</span> <span class="p">)</span> <span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">AUTO</span><span class="p">)</span> <span class="p">)</span> <span class="n">test_ds</span> <span class="o">=</span> <span class="p">(</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">))</span> <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span> <span class="o">.</span><span class="n">map</span><span class="p">(</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="p">(</span><span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="n">IMAGE_SIZE</span><span class="p">)),</span> <span class="n">y</span><span class="p">),</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTO</span><span class="p">,</span> <span class="p">)</span> <span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">AUTO</span><span class="p">)</span> <span class="p">)</span> </code></pre></div> <p>For comparison purposes, let's also define a simple augmentation pipeline consisting of random flips, random rotations, and random zoomings.</p> <div class="codehilite"><pre><span></span><code><span class="n">simple_aug</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">layers</span><span class="o">.</span><span class="n">Resizing</span><span class="p">(</span><span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="n">IMAGE_SIZE</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">RandomFlip</span><span class="p">(</span><span class="s2">"horizontal"</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">RandomRotation</span><span class="p">(</span><span class="n">factor</span><span class="o">=</span><span class="mf">0.02</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">RandomZoom</span><span class="p">(</span><span class="n">height_factor</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">width_factor</span><span class="o">=</span><span class="mf">0.2</span><span class="p">),</span> <span class="p">]</span> <span class="p">)</span> <span class="c1"># Now, map the augmentation pipeline to our training dataset</span> <span class="n">train_ds_simple</span> <span class="o">=</span> <span class="p">(</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">))</span> <span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">BATCH_SIZE</span> <span class="o">*</span> <span class="mi">100</span><span class="p">)</span> <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span> <span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="p">(</span><span class="n">simple_aug</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">y</span><span class="p">),</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTO</span><span class="p">)</span> <span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">AUTO</span><span class="p">)</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="visualize-the-dataset-augmented-with-randaugment">Visualize the dataset augmented with RandAugment</h2> <div class="codehilite"><pre><span></span><code><span class="n">sample_images</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">train_ds_rand</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">image</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">sample_images</span><span class="p">[:</span><span class="mi">9</span><span class="p">]):</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"int"</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/randaugment/randaugment_15_0.png" /></p> <p>You are encouraged to run the above code block a couple of times to see different variations.</p> <hr /> <h2 id="visualize-the-dataset-augmented-with-simpleaug">Visualize the dataset augmented with <code>simple_aug</code></h2> <div class="codehilite"><pre><span></span><code><span class="n">sample_images</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">train_ds_simple</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">image</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">sample_images</span><span class="p">[:</span><span class="mi">9</span><span class="p">]):</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"int"</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/randaugment/randaugment_18_0.png" /></p> <hr /> <h2 id="define-a-model-building-utility-function">Define a model building utility function</h2> <p>Now, we define a CNN model that is based on the <a href="https://arxiv.org/abs/1603.05027">ResNet50V2 architecture</a>. Also, notice that the network already has a rescaling layer inside it. This eliminates the need to do any separate preprocessing on our dataset and is specifically very useful for deployment purposes.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">get_training_model</span><span class="p">():</span> <span class="n">resnet50_v2</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">applications</span><span class="o">.</span><span class="n">ResNet50V2</span><span class="p">(</span> <span class="n">weights</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">include_top</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">input_shape</span><span class="o">=</span><span class="p">(</span><span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">classes</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">((</span><span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="mi">3</span><span class="p">)),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Rescaling</span><span class="p">(</span><span class="n">scale</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">/</span> <span class="mf">127.5</span><span class="p">,</span> <span class="n">offset</span><span class="o">=-</span><span class="mi">1</span><span class="p">),</span> <span class="n">resnet50_v2</span><span class="p">,</span> <span class="p">]</span> <span class="p">)</span> <span class="k">return</span> <span class="n">model</span> <span class="n">get_training_model</span><span class="p">()</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> </code></pre></div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "sequential_1"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ rescaling (<span style="color: #0087ff; text-decoration-color: #0087ff">Rescaling</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">72</span>, <span style="color: #00af00; text-decoration-color: #00af00">72</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ resnet50v2 (<span style="color: #0087ff; text-decoration-color: #0087ff">Functional</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">10</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">23,585,290</span> │ └─────────────────────────────────┴───────────────────────────┴────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">23,585,290</span> (89.97 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">23,539,850</span> (89.80 MB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">45,440</span> (177.50 KB) </pre> <p>We will train this network on two different versions of our dataset:</p> <ul> <li>One augmented with RandAugment.</li> <li>Another one augmented with <code>simple_aug</code>.</li> </ul> <p>Since RandAugment is known to enhance the robustness of models to common perturbations and corruptions, we will also evaluate our models on the CIFAR-10-C dataset, proposed in <a href="https://arxiv.org/abs/1903.12261">Benchmarking Neural Network Robustness to Common Corruptions and Perturbations</a> by Hendrycks et al. The CIFAR-10-C dataset consists of 19 different image corruptions and perturbations (for example speckle noise, fog, Gaussian blur, etc.) that too at varying severity levels. For this example we will be using the following configuration: <a href="https://www.tensorflow.org/datasets/catalog/cifar10_corrupted#cifar10_corruptedsaturate_5"><code>cifar10_corrupted/saturate_5</code></a>. The images from this configuration look like so:</p> <p><img alt="" src="https://storage.googleapis.com/tfds-data/visualization/fig/cifar10_corrupted-saturate_5-1.0.0.png" /></p> <p>In the interest of reproducibility, we serialize the initial random weights of our shallow network.</p> <div class="codehilite"><pre><span></span><code><span class="n">initial_model</span> <span class="o">=</span> <span class="n">get_training_model</span><span class="p">()</span> <span class="n">initial_model</span><span class="o">.</span><span class="n">save_weights</span><span class="p">(</span><span class="s2">"initial.weights.h5"</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="train-model-with-randaugment">Train model with RandAugment</h2> <div class="codehilite"><pre><span></span><code><span class="n">rand_aug_model</span> <span class="o">=</span> <span class="n">get_training_model</span><span class="p">()</span> <span class="n">rand_aug_model</span><span class="o">.</span><span class="n">load_weights</span><span class="p">(</span><span class="s2">"initial.weights.h5"</span><span class="p">)</span> <span class="n">rand_aug_model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"sparse_categorical_crossentropy"</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="s2">"adam"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"accuracy"</span><span class="p">]</span> <span class="p">)</span> <span class="n">rand_aug_model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_ds_rand</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">test_ds</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">EPOCHS</span><span class="p">)</span> <span class="n">_</span><span class="p">,</span> <span class="n">test_acc</span> <span class="o">=</span> <span class="n">rand_aug_model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">test_ds</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Test accuracy: </span><span class="si">{:.2f}</span><span class="s2">%"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">test_acc</span> <span class="o">*</span> <span class="mi">100</span><span class="p">))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 391/391 ━━━━━━━━━━━━━━━━━━━━ 1146s 3s/step - accuracy: 0.1677 - loss: 2.3232 - val_accuracy: 0.2818 - val_loss: 1.9966 79/79 ━━━━━━━━━━━━━━━━━━━━ 39s 489ms/step - accuracy: 0.2803 - loss: 2.0073 Test accuracy: 28.18% </code></pre></div> </div> <hr /> <h2 id="train-model-with-simpleaug">Train model with <code>simple_aug</code></h2> <div class="codehilite"><pre><span></span><code><span class="n">simple_aug_model</span> <span class="o">=</span> <span class="n">get_training_model</span><span class="p">()</span> <span class="n">simple_aug_model</span><span class="o">.</span><span class="n">load_weights</span><span class="p">(</span><span class="s2">"initial.weights.h5"</span><span class="p">)</span> <span class="n">simple_aug_model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"sparse_categorical_crossentropy"</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="s2">"adam"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"accuracy"</span><span class="p">]</span> <span class="p">)</span> <span class="n">simple_aug_model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_ds_simple</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">test_ds</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">EPOCHS</span><span class="p">)</span> <span class="n">_</span><span class="p">,</span> <span class="n">test_acc</span> <span class="o">=</span> <span class="n">simple_aug_model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">test_ds</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Test accuracy: </span><span class="si">{:.2f}</span><span class="s2">%"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">test_acc</span> <span class="o">*</span> <span class="mi">100</span><span class="p">))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 391/391 ━━━━━━━━━━━━━━━━━━━━ 1132s 3s/step - accuracy: 0.3673 - loss: 1.7929 - val_accuracy: 0.4789 - val_loss: 1.4296 79/79 ━━━━━━━━━━━━━━━━━━━━ 39s 494ms/step - accuracy: 0.4762 - loss: 1.4368 Test accuracy: 47.89% </code></pre></div> </div> <hr /> <h2 id="load-the-cifar10c-dataset-and-evaluate-performance">Load the CIFAR-10-C dataset and evaluate performance</h2> <div class="codehilite"><pre><span></span><code><span class="c1"># Load and prepare the CIFAR-10-C dataset</span> <span class="c1"># (If it's not already downloaded, it takes ~10 minutes of time to download)</span> <span class="n">cifar_10_c</span> <span class="o">=</span> <span class="n">tfds</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="s2">"cifar10_corrupted/saturate_5"</span><span class="p">,</span> <span class="n">split</span><span class="o">=</span><span class="s2">"test"</span><span class="p">,</span> <span class="n">as_supervised</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">cifar_10_c</span> <span class="o">=</span> <span class="n">cifar_10_c</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span><span class="o">.</span><span class="n">map</span><span class="p">(</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="p">(</span><span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="n">IMAGE_SIZE</span><span class="p">)),</span> <span class="n">y</span><span class="p">),</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTO</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Evaluate `rand_aug_model`</span> <span class="n">_</span><span class="p">,</span> <span class="n">test_acc</span> <span class="o">=</span> <span class="n">rand_aug_model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">cifar_10_c</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span> <span class="s2">"Accuracy with RandAugment on CIFAR-10-C (saturate_5): </span><span class="si">{:.2f}</span><span class="s2">%"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span> <span class="n">test_acc</span> <span class="o">*</span> <span class="mi">100</span> <span class="p">)</span> <span class="p">)</span> <span class="c1"># Evaluate `simple_aug_model`</span> <span class="n">_</span><span class="p">,</span> <span class="n">test_acc</span> <span class="o">=</span> <span class="n">simple_aug_model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">cifar_10_c</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span> <span class="s2">"Accuracy with simple_aug on CIFAR-10-C (saturate_5): </span><span class="si">{:.2f}</span><span class="s2">%"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span> <span class="n">test_acc</span> <span class="o">*</span> <span class="mi">100</span> <span class="p">)</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> Downloading and preparing dataset 2.72 GiB (download: 2.72 GiB, generated: Unknown size, total: 2.72 GiB) to /home/sachinprasad/tensorflow_datasets/cifar10_corrupted/saturate_5/1.0.0... Dataset cifar10_corrupted downloaded and prepared to /home/sachinprasad/tensorflow_datasets/cifar10_corrupted/saturate_5/1.0.0. Subsequent calls will reuse this data. Accuracy with RandAugment on CIFAR-10-C (saturate_5): 30.36% Accuracy with simple_aug on CIFAR-10-C (saturate_5): 37.18% </code></pre></div> </div> <p>For the purpose of this example, we trained the models for only a single epoch. On the CIFAR-10-C dataset, the model with RandAugment can perform better with a higher accuracy (for example, 76.64% in one experiment) compared with the model trained with <code>simple_aug</code> (e.g., 64.80%). RandAugment can also help stabilize the training.</p> <p>In the notebook, you may notice that, at the expense of increased training time with RandAugment, we are able to carve out far better performance on the CIFAR-10-C dataset. You can experiment on the other corruption and perturbation settings that come with the run the same CIFAR-10-C dataset and see if RandAugment helps.</p> <p>You can also experiment with the different values of <code>n</code> and <code>m</code> in the <code>RandAugment</code> object. In the <a href="https://arxiv.org/abs/1909.13719">original paper</a>, the authors show the impact of the individual augmentation transforms for a particular task and a range of ablation studies. You are welcome to check them out.</p> <p>RandAugment has shown great progress in improving the robustness of deep models for computer vision as shown in works like <a href="https://arxiv.org/abs/1911.04252">Noisy Student Training</a> and <a href="https://arxiv.org/abs/2001.07685">FixMatch</a>. This makes RandAugment quite a useful recipe for training different vision models.</p> <p>You can use the trained model hosted on <a href="https://huggingface.co/keras-io/randaugment">Hugging Face Hub</a> and try the demo on <a href="https://huggingface.co/spaces/keras-io/randaugment">Hugging Face Spaces</a>.</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#randaugment-for-image-classification-for-improved-robustness'>RandAugment for Image Classification for Improved Robustness</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#imports-amp-setup'>Imports & setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#load-the-cifar10-dataset'>Load the CIFAR10 dataset</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#define-hyperparameters'>Define hyperparameters</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#initialize-randaugment-object'>Initialize <code>RandAugment</code> object</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#create-tensorflow-dataset-objects'>Create TensorFlow <code>Dataset</code> objects</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#visualize-the-dataset-augmented-with-randaugment'>Visualize the dataset augmented with RandAugment</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#visualize-the-dataset-augmented-with-simpleaug'>Visualize the dataset augmented with <code>simple_aug</code></a> </div> <div class='k-outline-depth-2'> ◆ <a href='#define-a-model-building-utility-function'>Define a model building utility function</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-model-with-randaugment'>Train model with RandAugment</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-model-with-simpleaug'>Train model with <code>simple_aug</code></a> </div> <div class='k-outline-depth-2'> ◆ <a href='#load-the-cifar10c-dataset-and-evaluate-performance'>Load the CIFAR-10-C dataset and evaluate performance</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>