# Semi-supervised image classification using contrastive pretraining with SimCLR

**Author:** [András Béres](https://www.linkedin.com/in/andras-beres-789190210)<br>
**Date created:** 2021/04/24<br>
**Last modified:** 2024/03/04<br>
**Description:** Contrastive pretraining with SimCLR for semi-supervised image classification on the STL-10 dataset.

ⓘ This example uses Keras 3

[**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/semisupervised_simclr.ipynb) • [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/vision/semisupervised_simclr.py)

---

## Introduction

### Semi-supervised learning

Semi-supervised learning is a machine learning paradigm that deals with **partially labeled datasets**. When applying deep learning in the real world, one usually has to gather a large dataset to make it work well.
However, while the cost of labeling scales linearly with the dataset size (labeling each example takes a constant time), model performance only scales [sublinearly](https://arxiv.org/abs/2001.08361) with it. This means that labeling more and more samples becomes less and less cost-efficient, while gathering unlabeled data is generally cheap, as it is usually readily available in large quantities.

Semi-supervised learning addresses this problem by requiring only a partially labeled dataset, and by being label-efficient: it utilizes the unlabeled examples for learning as well.

In this example, we will pretrain an encoder with contrastive learning on the [STL-10](https://ai.stanford.edu/~acoates/stl10/) semi-supervised dataset using no labels at all, and then fine-tune it using only its labeled subset.

### Contrastive learning

At the highest level, the main idea behind contrastive learning is to **learn representations that are invariant to image augmentations** in a self-supervised manner. One problem with this objective is that it has a trivial degenerate solution: the case where the representations are constant and do not depend at all on the input images.

Contrastive learning avoids this trap by modifying the objective in the following way: it pulls representations of augmented versions/views of the same image closer to each other (contracting positives), while simultaneously pushing different images away from each other (contrasting negatives) in representation space.

One such contrastive approach is [SimCLR](https://arxiv.org/abs/2002.05709), which essentially identifies the core components needed to optimize this objective, and can achieve high performance by scaling this simple approach.

Another approach is [SimSiam](https://arxiv.org/abs/2011.10566) ([Keras example](https://keras.io/examples/vision/simsiam/)), whose main difference from SimCLR is that it does not use any negatives in its loss.
Therefore, it does not explicitly prevent the trivial solution; instead, it avoids it implicitly by architecture design (asymmetric encoding paths using a predictor network, and batch normalization (BatchNorm) applied in the final layers).

For further reading about SimCLR, check out [the official Google AI blog post](https://ai.googleblog.com/2020/04/advancing-self-supervised-and-semi.html), and for an overview of self-supervised learning across both vision and language, check out [this blog post](https://ai.facebook.com/blog/self-supervised-learning-the-dark-matter-of-intelligence/).

---

## Setup

```python
import os

os.environ["KERAS_BACKEND"] = "tensorflow"


# Make sure we are able to handle large datasets
import resource

low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))

import math
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

import keras
from keras import ops
from keras import layers
```

---

## Hyperparameter setup

```python
# Dataset hyperparameters
unlabeled_dataset_size = 100000
labeled_dataset_size = 5000
image_channels = 3

# Algorithm hyperparameters
num_epochs = 20
batch_size = 525  # Corresponds to 200 steps per epoch
width = 128
temperature = 0.1
# Stronger augmentations for contrastive, weaker ones for supervised training
contrastive_augmentation = {"min_area": 0.25, "brightness": 0.6, "jitter": 0.2}
classification_augmentation = {
    "min_area": 0.75,
    "brightness": 0.3,
    "jitter": 0.1,
}
```
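A quick back-of-the-envelope check of the batch sizing above (an illustrative sketch, not part of the training code): with 100,000 unlabeled and 5,000 labeled images, a combined batch of 525 gives 200 steps per epoch, which the data pipeline below splits into 500 unlabeled + 25 labeled images per step.

```python
# Illustrative arithmetic behind the "200 steps per epoch" comment above
steps_per_epoch = (100000 + 5000) // 525        # = 200
unlabeled_per_step = 100000 // steps_per_epoch  # = 500 unlabeled images per step
labeled_per_step = 5000 // steps_per_epoch      # = 25 labeled images per step
print(steps_per_epoch, unlabeled_per_step, labeled_per_step)  # 200 500 25
```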
---

## Dataset

During training we will simultaneously load a large batch of unlabeled images along with a smaller batch of labeled images.

```python
def prepare_dataset():
    # Labeled and unlabeled samples are loaded synchronously
    # with batch sizes selected accordingly
    steps_per_epoch = (unlabeled_dataset_size + labeled_dataset_size) // batch_size
    unlabeled_batch_size = unlabeled_dataset_size // steps_per_epoch
    labeled_batch_size = labeled_dataset_size // steps_per_epoch
    print(
        f"batch size is {unlabeled_batch_size} (unlabeled) + {labeled_batch_size} (labeled)"
    )

    # Turning off shuffle to lower resource usage
    unlabeled_train_dataset = (
        tfds.load("stl10", split="unlabelled", as_supervised=True, shuffle_files=False)
        .shuffle(buffer_size=10 * unlabeled_batch_size)
        .batch(unlabeled_batch_size)
    )
    labeled_train_dataset = (
        tfds.load("stl10", split="train", as_supervised=True, shuffle_files=False)
        .shuffle(buffer_size=10 * labeled_batch_size)
        .batch(labeled_batch_size)
    )
    test_dataset = (
        tfds.load("stl10", split="test", as_supervised=True)
        .batch(batch_size)
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )

    # Labeled and unlabeled datasets are zipped together
    train_dataset = tf.data.Dataset.zip(
        (unlabeled_train_dataset, labeled_train_dataset)
    ).prefetch(buffer_size=tf.data.AUTOTUNE)

    return train_dataset, labeled_train_dataset, test_dataset


# Load STL10 dataset
train_dataset, labeled_train_dataset, test_dataset = prepare_dataset()
```

```
batch size is 500 (unlabeled) + 25 (labeled)
```

---

## Image augmentations
The two most important image augmentations for contrastive learning are the following:

- Cropping: forces the model to encode different parts of the same image similarly; we implement it with the [RandomTranslation](https://keras.io/api/layers/preprocessing_layers/image_augmentation/random_translation/) and [RandomZoom](https://keras.io/api/layers/preprocessing_layers/image_augmentation/random_zoom/) layers.
- Color jitter: prevents a trivial color histogram-based solution to the task by distorting color histograms. A principled way to implement that is by affine transformations in color space (sketched concretely below, just before the layer implementation).

In this example we use random horizontal flips as well. Stronger augmentations are applied for contrastive learning, along with weaker ones for supervised classification, to avoid overfitting on the few labeled examples.

We implement random color jitter as a custom preprocessing layer. Using preprocessing layers for data augmentation has the following two advantages:

- The data augmentation will run on GPU in batches, so the training will not be bottlenecked by the data pipeline in environments with constrained CPU resources (such as a Colab Notebook, or a personal machine).
- Deployment is easier as the data preprocessing pipeline is encapsulated in the model, and does not have to be reimplemented when deploying it.
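To make "affine transformations in color space" concrete, here is a minimal NumPy sketch of the transform that the `RandomColorAffine` layer below applies to every pixel (illustrative only; the variable names here are hypothetical and not part of the example):

```python
import numpy as np

rng = np.random.default_rng(0)
brightness, jitter = 0.6, 0.2  # the "contrastive_augmentation" settings above

s = rng.uniform(-brightness, brightness)       # brightness scale, shared by all channels
J = rng.uniform(-jitter, jitter, size=(3, 3))  # random channel-mixing matrix
M = (1 + s) * np.eye(3) + J                    # the color transform

pixel = np.array([0.2, 0.5, 0.8])              # an RGB value in [0, 1]
augmented = np.clip(pixel @ M, 0.0, 1.0)       # same matmul + clip as the layer below
print(augmented)
```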
```python
# Distorts the color distributions of images
class RandomColorAffine(layers.Layer):
    def __init__(self, brightness=0, jitter=0, **kwargs):
        super().__init__(**kwargs)

        self.seed_generator = keras.random.SeedGenerator(1337)
        self.brightness = brightness
        self.jitter = jitter

    def get_config(self):
        config = super().get_config()
        config.update({"brightness": self.brightness, "jitter": self.jitter})
        return config

    def call(self, images, training=True):
        if training:
            batch_size = ops.shape(images)[0]

            # Same for all colors
            brightness_scales = 1 + keras.random.uniform(
                (batch_size, 1, 1, 1),
                minval=-self.brightness,
                maxval=self.brightness,
                seed=self.seed_generator,
            )
            # Different for all colors
            jitter_matrices = keras.random.uniform(
                (batch_size, 1, 3, 3),
                minval=-self.jitter,
                maxval=self.jitter,
                seed=self.seed_generator,
            )
            color_transforms = (
                ops.tile(ops.expand_dims(ops.eye(3), axis=0), (batch_size, 1, 1, 1))
                * brightness_scales
                + jitter_matrices
            )
            images = ops.clip(ops.matmul(images, color_transforms), 0, 1)
        return images


# Image augmentation module
def get_augmenter(min_area, brightness, jitter):
    zoom_factor = 1.0 - math.sqrt(min_area)
    return keras.Sequential(
        [
            layers.Rescaling(1 / 255),
            layers.RandomFlip("horizontal"),
            layers.RandomTranslation(zoom_factor / 2, zoom_factor / 2),
            layers.RandomZoom((-zoom_factor, 0.0), (-zoom_factor, 0.0)),
            RandomColorAffine(brightness, jitter),
        ]
    )


def visualize_augmentations(num_images):
    # Sample a batch from a dataset
    images = next(iter(train_dataset))[0][0][:num_images]
    # Apply augmentations
    augmented_images = zip(
        images,
        get_augmenter(**classification_augmentation)(images),
        get_augmenter(**contrastive_augmentation)(images),
        get_augmenter(**contrastive_augmentation)(images),
    )
    row_titles = [
        "Original:",
        "Weakly augmented:",
        "Strongly augmented:",
        "Strongly augmented:",
    ]
    plt.figure(figsize=(num_images * 2.2, 4 * 2.2), dpi=100)
    for column, image_row in enumerate(augmented_images):
        for row, image in enumerate(image_row):
            plt.subplot(4, num_images, row * num_images + column + 1)
            plt.imshow(image)
            if column == 0:
                plt.title(row_titles[row], loc="left")
            plt.axis("off")
    plt.tight_layout()


visualize_augmentations(num_images=8)
```

![png](/img/examples/vision/semisupervised_simclr/semisupervised_simclr_9_0.png)

---

## Encoder architecture

```python
# Define the encoder architecture
def get_encoder():
    return keras.Sequential(
        [
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Conv2D(width, kernel_size=3, strides=2, activation="relu"),
            layers.Flatten(),
            layers.Dense(width, activation="relu"),
        ],
        name="encoder",
    )
```
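For orientation, a rough output-shape calculation for this encoder (a sketch assuming the 96×96 STL-10 input resolution and `Conv2D`'s default `"valid"` padding, neither of which is stated explicitly in the code above):

```python
# Illustrative size bookkeeping for the four stride-2, kernel-3 convolutions
size = 96
for _ in range(4):
    size = (size - 3) // 2 + 1  # 96 -> 47 -> 23 -> 11 -> 5
print(size)  # 5, so Flatten() sees a 5 x 5 x 128 feature map before the final Dense(width)
```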
---

## Supervised baseline model

A baseline supervised model is trained using random initialization.

```python
# Baseline supervised training with random initialization
baseline_model = keras.Sequential(
    [
        get_augmenter(**classification_augmentation),
        get_encoder(),
        layers.Dense(10),
    ],
    name="baseline_model",
)
baseline_model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[keras.metrics.SparseCategoricalAccuracy(name="acc")],
)

baseline_history = baseline_model.fit(
    labeled_train_dataset, epochs=num_epochs, validation_data=test_dataset
)

print(
    "Maximal validation accuracy: {:.2f}%".format(
        max(baseline_history.history["val_acc"]) * 100
    )
)
```

```
Epoch 1/20
200/200 ━━━━━━━━━━━━━━━━━━━━ 9s 25ms/step - acc: 0.2031 - loss: 2.1576 - val_acc: 0.3234 - val_loss: 1.7719
Epoch 2/20
200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.3476 - loss: 1.7792 - val_acc: 0.4042 - val_loss: 1.5626
Epoch 3/20
200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.4060 - loss: 1.6054 - val_acc: 0.4319 - val_loss: 1.4832
Epoch 4/20
200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - acc: 0.4347 - loss: 1.5052 - val_acc: 0.4570 - val_loss: 1.4428
Epoch 5/20
200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - acc: 0.4600 - loss: 1.4546 - val_acc: 0.4765 - val_loss: 1.3977
Epoch 6/20
200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.4754 - loss: 1.4015 - val_acc: 0.4740 - val_loss: 1.4082
Epoch 7/20
200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.4901 - loss: 1.3589 - val_acc: 0.4761 - val_loss: 1.4061
Epoch 8/20
200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.5110 - loss: 1.2793 - val_acc: 0.5247 - val_loss: 1.3026
Epoch 9/20
200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.5298 - loss: 1.2765 - val_acc: 0.5138 - val_loss: 1.3286
Epoch 10/20
200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.5514 - loss: 1.2078 - val_acc: 0.5543 - val_loss: 1.2227
Epoch 11/20
200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.5520 - loss: 1.1851 - val_acc: 0.5446 - val_loss: 1.2709
Epoch 12/20
200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.5851 - loss: 1.1368 - val_acc: 0.5725 - val_loss: 1.1944
Epoch 13/20
200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - acc: 0.5738 - loss: 1.1411 - val_acc: 0.5685 - val_loss: 1.1974
Epoch 14/20
200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 21ms/step - acc: 0.6078 - loss: 1.0308 - val_acc: 0.5899 - val_loss: 1.1769
Epoch 15/20
200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - acc: 0.6284 - loss: 1.0386 - val_acc: 0.5863 - val_loss: 1.1742
Epoch 16/20
200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 18ms/step - acc: 0.6450 - loss: 0.9773 - val_acc: 0.5849 - val_loss: 1.1993
Epoch 17/20
200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.6547 - loss: 0.9555 - val_acc: 0.5683 - val_loss: 1.2424
Epoch 18/20
200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.6593 - loss: 0.9084 - val_acc: 0.5990 - val_loss: 1.1458
Epoch 19/20
200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.6672 - loss: 0.9267 - val_acc: 0.5685 - val_loss: 1.2758
Epoch 20/20
200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.6824 - loss: 0.8863 - val_acc: 0.5969 - val_loss: 1.2035
Maximal validation accuracy: 59.90%
```
---

## Self-supervised model for contrastive pretraining

We pretrain an encoder on unlabeled images with a contrastive loss. A nonlinear projection head is attached to the top of the encoder, as it improves the quality of the encoder's representations.

We use the InfoNCE/NT-Xent/N-pairs loss, which can be interpreted in the following way:

1. We treat each image in the batch as if it had its own class.
2. Then, we have two examples (a pair of augmented views) for each "class".
3. Each view's representation is compared to the representation of every possible pair (for both augmented versions).
4. We use the temperature-scaled cosine similarity of the compared representations as logits.
5. Finally, we use categorical cross-entropy as the "classification" loss (written out as a formula below).
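Written as a formula (a sketch in the notation of the SimCLR paper, not text from the original example): with $z$ denoting the ℓ2-normalized projections of the two views of image $i$, $\tau$ the temperature, and $N$ the batch size, the loss term for one view ordering is

$$
\ell_i = -\log \frac{\exp\!\left(z_i^{(1)} \cdot z_i^{(2)} / \tau\right)}{\sum_{j=1}^{N} \exp\!\left(z_i^{(1)} \cdot z_j^{(2)} / \tau\right)},
$$

averaged over the batch; the symmetric term with the roles of the two views swapped is typically averaged in as well.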
The following two metrics are used for monitoring the pretraining performance:

- [Contrastive accuracy (SimCLR Table 5)](https://arxiv.org/abs/2002.05709): a self-supervised metric, the ratio of cases in which the representation of an image is more similar to the representation of its differently augmented version than to the representation of any other image in the current batch (sketched in code after this list). Self-supervised metrics can be used for hyperparameter tuning even when there are no labeled examples.
- [Linear probing accuracy](https://arxiv.org/abs/1603.08511): linear probing is a popular metric for evaluating self-supervised classifiers. It is computed as the accuracy of a logistic regression classifier trained on top of the encoder's features. In our case, this is done by training a single dense layer on top of the frozen encoder. Note that contrary to the traditional approach, where the classifier is trained after the pretraining phase, in this example we train it during pretraining. This might slightly decrease its accuracy, but that way we can monitor its value during training, which helps with experimentation and debugging.

Another widely used supervised metric is the [KNN accuracy](https://arxiv.org/abs/1805.01978): the accuracy of a KNN classifier trained on top of the encoder's features. It is not implemented in this example.
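To make contrastive accuracy concrete, here is a minimal NumPy sketch (illustrative only; `contrastive_accuracy_sketch` is a hypothetical helper, not part of the model below): given the matrix of similarities between the projections of the two augmented views, an image counts as correct when its most similar projection in the other view is its own pair.

```python
import numpy as np


def contrastive_accuracy_sketch(similarities: np.ndarray) -> float:
    # similarities[i, j]: similarity between view 1 of image i and view 2 of image j
    predictions = similarities.argmax(axis=1)  # most similar image in the other view
    labels = np.arange(similarities.shape[0])  # the correct match is the image itself
    return float(np.mean(predictions == labels))
```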
class="p">[</span><span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">width</span><span class="p">,)),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">)],</span> <span class="n">name</span><span class="o">=</span><span class="s2">"linear_probe"</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection_head</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_probe</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> <span class="k">def</span> <span class="nf">compile</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">contrastive_optimizer</span><span class="p">,</span> <span class="n">probe_optimizer</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">contrastive_optimizer</span> <span class="o">=</span> <span class="n">contrastive_optimizer</span> <span class="bp">self</span><span class="o">.</span><span class="n">probe_optimizer</span> <span class="o">=</span> <span class="n">probe_optimizer</span> <span class="c1"># self.contrastive_loss will be defined as a method</span> <span class="bp">self</span><span class="o">.</span><span class="n">probe_loss</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">contrastive_loss_tracker</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"c_loss"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">contrastive_accuracy</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">SparseCategoricalAccuracy</span><span class="p">(</span> <span class="n">name</span><span class="o">=</span><span class="s2">"c_acc"</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">probe_loss_tracker</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span 
class="s2">"p_loss"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">probe_accuracy</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">SparseCategoricalAccuracy</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"p_acc"</span><span class="p">)</span> <span class="nd">@property</span> <span class="k">def</span> <span class="nf">metrics</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="k">return</span> <span class="p">[</span> <span class="bp">self</span><span class="o">.</span><span class="n">contrastive_loss_tracker</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">contrastive_accuracy</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">probe_loss_tracker</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">probe_accuracy</span><span class="p">,</span> <span class="p">]</span> <span class="k">def</span> <span class="nf">contrastive_loss</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">projections_1</span><span class="p">,</span> <span class="n">projections_2</span><span class="p">):</span> <span class="c1"># InfoNCE loss (information noise-contrastive estimation)</span> <span class="c1"># NT-Xent loss (normalized temperature-scaled cross entropy)</span> <span class="c1"># Cosine similarity: the dot product of the l2-normalized feature vectors</span> <span class="n">projections_1</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">normalize</span><span class="p">(</span><span class="n">projections_1</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="n">projections_2</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">normalize</span><span class="p">(</span><span class="n">projections_2</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="n">similarities</span> <span class="o">=</span> <span class="p">(</span> <span class="n">ops</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">projections_1</span><span class="p">,</span> <span class="n">ops</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">projections_2</span><span class="p">))</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">temperature</span> <span class="p">)</span> <span class="c1"># The similarity between the representations of two augmented views of the</span> <span class="c1"># same image should be higher than their similarity with other views</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">projections_1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="n">contrastive_labels</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">arange</span><span 
class="p">(</span><span class="n">batch_size</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">contrastive_accuracy</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">contrastive_labels</span><span class="p">,</span> <span class="n">similarities</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">contrastive_accuracy</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span> <span class="n">contrastive_labels</span><span class="p">,</span> <span class="n">ops</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">similarities</span><span class="p">)</span> <span class="p">)</span> <span class="c1"># The temperature-scaled similarities are used as logits for cross-entropy</span> <span class="c1"># a symmetrized version of the loss is used here</span> <span class="n">loss_1_2</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">sparse_categorical_crossentropy</span><span class="p">(</span> <span class="n">contrastive_labels</span><span class="p">,</span> <span class="n">similarities</span><span class="p">,</span> <span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="n">loss_2_1</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">sparse_categorical_crossentropy</span><span class="p">(</span> <span class="n">contrastive_labels</span><span class="p">,</span> <span class="n">ops</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">similarities</span><span class="p">),</span> <span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="k">return</span> <span class="p">(</span><span class="n">loss_1_2</span> <span class="o">+</span> <span class="n">loss_2_1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span> <span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span> <span class="p">(</span><span class="n">unlabeled_images</span><span class="p">,</span> <span class="n">_</span><span class="p">),</span> <span class="p">(</span><span class="n">labeled_images</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span> <span class="o">=</span> <span class="n">data</span> <span class="c1"># Both labeled and unlabeled images are used, without labels</span> <span class="n">images</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">unlabeled_images</span><span class="p">,</span> <span class="n">labeled_images</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="c1"># Each image is augmented twice, differently</span> <span class="n">augmented_images_1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">contrastive_augmenter</span><span class="p">(</span><span 
class="n">images</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">augmented_images_2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">contrastive_augmenter</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="n">features_1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">augmented_images_1</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">features_2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">augmented_images_2</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="c1"># The representations are passed through a projection mlp</span> <span class="n">projections_1</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection_head</span><span class="p">(</span><span class="n">features_1</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">projections_2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection_head</span><span class="p">(</span><span class="n">features_2</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">contrastive_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">contrastive_loss</span><span class="p">(</span><span class="n">projections_1</span><span class="p">,</span> <span class="n">projections_2</span><span class="p">)</span> <span class="n">gradients</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span> <span class="n">contrastive_loss</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="o">.</span><span class="n">trainable_weights</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">projection_head</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">contrastive_optimizer</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span> <span class="nb">zip</span><span class="p">(</span> <span class="n">gradients</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="o">.</span><span class="n">trainable_weights</span> <span class="o">+</span> <span 
class="bp">self</span><span class="o">.</span><span class="n">projection_head</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">,</span> <span class="p">)</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">contrastive_loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">contrastive_loss</span><span class="p">)</span> <span class="c1"># Labels are only used in evalutation for an on-the-fly logistic regression</span> <span class="n">preprocessed_images</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">classification_augmenter</span><span class="p">(</span> <span class="n">labeled_images</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="c1"># the encoder is used in inference mode here to avoid regularization</span> <span class="c1"># and updating the batch normalization paramers if they are used</span> <span class="n">features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">preprocessed_images</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="n">class_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_probe</span><span class="p">(</span><span class="n">features</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">probe_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">probe_loss</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">class_logits</span><span class="p">)</span> <span class="n">gradients</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">probe_loss</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_probe</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">probe_optimizer</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span> <span class="nb">zip</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_probe</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">)</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">probe_loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">probe_loss</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">probe_accuracy</span><span class="o">.</span><span 
class="n">update_state</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">class_logits</span><span class="p">)</span> <span class="k">return</span> <span class="p">{</span><span class="n">m</span><span class="o">.</span><span class="n">name</span><span class="p">:</span> <span class="n">m</span><span class="o">.</span><span class="n">result</span><span class="p">()</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics</span><span class="p">}</span> <span class="k">def</span> <span class="nf">test_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span> <span class="n">labeled_images</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">data</span> <span class="c1"># For testing the components are used with a training=False flag</span> <span class="n">preprocessed_images</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">classification_augmenter</span><span class="p">(</span> <span class="n">labeled_images</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">False</span> <span class="p">)</span> <span class="n">features</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">preprocessed_images</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="n">class_logits</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_probe</span><span class="p">(</span><span class="n">features</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="n">probe_loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">probe_loss</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">class_logits</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">probe_loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">probe_loss</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">probe_accuracy</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">labels</span><span class="p">,</span> <span class="n">class_logits</span><span class="p">)</span> <span class="c1"># Only the probe metrics are logged at test time</span> <span class="k">return</span> <span class="p">{</span><span class="n">m</span><span class="o">.</span><span class="n">name</span><span class="p">:</span> <span class="n">m</span><span class="o">.</span><span class="n">result</span><span class="p">()</span> <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics</span><span class="p">[</span><span class="mi">2</span><span class="p">:]}</span> <span class="c1"># Contrastive pretraining</span> <span class="n">pretraining_model</span> <span 
class="o">=</span> <span class="n">ContrastiveModel</span><span class="p">()</span> <span class="n">pretraining_model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">contrastive_optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(),</span> <span class="n">probe_optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(),</span> <span class="p">)</span> <span class="n">pretraining_history</span> <span class="o">=</span> <span class="n">pretraining_model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">train_dataset</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">num_epochs</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">test_dataset</span> <span class="p">)</span> <span class="nb">print</span><span class="p">(</span> <span class="s2">"Maximal validation accuracy: </span><span class="si">{:.2f}</span><span class="s2">%"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span> <span class="nb">max</span><span class="p">(</span><span class="n">pretraining_history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">"val_p_acc"</span><span class="p">])</span> <span class="o">*</span> <span class="mi">100</span> <span class="p">)</span> <span class="p">)</span> </code></pre></div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "encoder"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ conv2d_4 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ ? │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ │ │ │ (unbuilt) │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_5 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ ? │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ │ │ │ (unbuilt) │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_6 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ ? │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ │ │ │ (unbuilt) │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ conv2d_7 (<span style="color: #0087ff; text-decoration-color: #0087ff">Conv2D</span>) │ ? │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ │ │ │ (unbuilt) │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ flatten_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">Flatten</span>) │ ? 
│ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ │ │ │ (unbuilt) │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dense_2 (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ ? │ <span style="color: #00af00; text-decoration-color: #00af00">0</span> │ │ │ │ (unbuilt) │ └─────────────────────────────────┴───────────────────────────┴────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "projection_head"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ dense_3 (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">16,512</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dense_4 (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">128</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">16,512</span> │ └─────────────────────────────────┴───────────────────────────┴────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">33,024</span> (129.00 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">33,024</span> (129.00 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) 
</pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "linear_probe"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ dense_5 (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">10</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">1,290</span> │ └─────────────────────────────────┴───────────────────────────┴────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">1,290</span> (5.04 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">1,290</span> (5.04 KB) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 34s 134ms/step - c_acc: 0.0880 - c_loss: 5.2606 - p_acc: 0.1326 - p_loss: 2.2726 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.2579 - val_p_loss: 2.0671 Epoch 2/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 29s 139ms/step - c_acc: 0.2808 - c_loss: 3.6233 - p_acc: 0.2956 - p_loss: 2.0228 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.3440 - val_p_loss: 1.9242 Epoch 3/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 28s 136ms/step - c_acc: 0.4097 - c_loss: 2.9369 - p_acc: 0.3671 - p_loss: 1.8674 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.3876 - val_p_loss: 1.7757 Epoch 4/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 30s 142ms/step - c_acc: 0.4893 - c_loss: 2.5707 - p_acc: 0.3957 - p_loss: 1.7490 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.3960 - val_p_loss: 1.7002 Epoch 5/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 28s 136ms/step - c_acc: 0.5458 - c_loss: 2.3342 - p_acc: 0.4274 - p_loss: 1.6608 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.4374 - val_p_loss: 1.6145 Epoch 6/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 29s 140ms/step - c_acc: 0.5949 - c_loss: 2.1179 - p_acc: 0.4410 - p_loss: 1.5812 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.4444 - val_p_loss: 1.5439 Epoch 7/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 28s 135ms/step - c_acc: 0.6273 - c_loss: 1.9861 - p_acc: 0.4633 - p_loss: 1.5076 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.4695 - val_p_loss: 1.5056 Epoch 8/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 29s 139ms/step - c_acc: 0.6566 - c_loss: 1.8668 - p_acc: 
0.4817 - p_loss: 1.4601 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.4790 - val_p_loss: 1.4566 Epoch 9/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 28s 135ms/step - c_acc: 0.6726 - c_loss: 1.7938 - p_acc: 0.4885 - p_loss: 1.4136 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.4933 - val_p_loss: 1.4163 Epoch 10/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 29s 139ms/step - c_acc: 0.6931 - c_loss: 1.7210 - p_acc: 0.4954 - p_loss: 1.3663 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.5140 - val_p_loss: 1.3677 Epoch 11/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 29s 137ms/step - c_acc: 0.7055 - c_loss: 1.6619 - p_acc: 0.5210 - p_loss: 1.3376 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.5155 - val_p_loss: 1.3573 Epoch 12/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 30s 145ms/step - c_acc: 0.7215 - c_loss: 1.6112 - p_acc: 0.5264 - p_loss: 1.2920 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.5232 - val_p_loss: 1.3337 Epoch 13/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 31s 146ms/step - c_acc: 0.7279 - c_loss: 1.5749 - p_acc: 0.5388 - p_loss: 1.2570 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.5217 - val_p_loss: 1.3155 Epoch 14/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 29s 140ms/step - c_acc: 0.7435 - c_loss: 1.5196 - p_acc: 0.5505 - p_loss: 1.2507 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.5460 - val_p_loss: 1.2640 Epoch 15/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 40s 135ms/step - c_acc: 0.7477 - c_loss: 1.4979 - p_acc: 0.5653 - p_loss: 1.2188 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.5594 - val_p_loss: 1.2351 Epoch 16/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 29s 139ms/step - c_acc: 0.7598 - c_loss: 1.4463 - p_acc: 0.5590 - p_loss: 1.1917 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.5551 - val_p_loss: 1.2411 Epoch 17/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 28s 135ms/step - c_acc: 0.7633 - c_loss: 1.4271 - p_acc: 0.5775 - p_loss: 1.1731 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.5502 - val_p_loss: 1.2428 Epoch 18/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 29s 140ms/step - c_acc: 0.7666 - c_loss: 1.4246 - p_acc: 0.5752 - p_loss: 1.1805 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.5633 - val_p_loss: 1.2167 Epoch 19/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 28s 135ms/step - c_acc: 0.7708 - c_loss: 1.3928 - p_acc: 0.5814 - p_loss: 1.1677 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.5665 - val_p_loss: 1.2191 Epoch 20/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 29s 140ms/step - c_acc: 0.7806 - c_loss: 1.3733 - p_acc: 0.5836 - p_loss: 1.1442 - val_c_acc: 0.0000e+00 - val_c_loss: 0.0000e+00 - val_p_acc: 0.5640 - val_p_loss: 1.2172 Maximal validation accuracy: 56.65% </code></pre></div> </div> <hr /> <h2 id="supervised-finetuning-of-the-pretrained-encoder">Supervised finetuning of the pretrained encoder</h2> <p>We then finetune the encoder on the labeled examples by attaching a single randomly initialized fully connected classification layer on top of it.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Supervised finetuning of the pretrained encoder</span> <span class="n">finetuning_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">get_augmenter</span><span class="p">(</span><span class="o">**</span><span class="n">classification_augmentation</span><span class="p">),</span> <span class="n">pretraining_model</span><span class="o">.</span><span
class="n">encoder</span><span class="p">,</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span> <span class="p">],</span> <span class="n">name</span><span class="o">=</span><span class="s2">"finetuning_model"</span><span class="p">,</span> <span class="p">)</span> <span class="n">finetuning_model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(),</span> <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">SparseCategoricalAccuracy</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"acc"</span><span class="p">)],</span> <span class="p">)</span> <span class="n">finetuning_history</span> <span class="o">=</span> <span class="n">finetuning_model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">labeled_train_dataset</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">num_epochs</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">test_dataset</span> <span class="p">)</span> <span class="nb">print</span><span class="p">(</span> <span class="s2">"Maximal validation accuracy: </span><span class="si">{:.2f}</span><span class="s2">%"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span> <span class="nb">max</span><span class="p">(</span><span class="n">finetuning_history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">"val_acc"</span><span class="p">])</span> <span class="o">*</span> <span class="mi">100</span> <span class="p">)</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 5s 18ms/step - acc: 0.2104 - loss: 2.0930 - val_acc: 0.4017 - val_loss: 1.5433 Epoch 2/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.4037 - loss: 1.5791 - val_acc: 0.4544 - val_loss: 1.4250 Epoch 3/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.4639 - loss: 1.4161 - val_acc: 0.5266 - val_loss: 1.2958 Epoch 4/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.5438 - loss: 1.2686 - val_acc: 0.5655 - val_loss: 1.1711 Epoch 5/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.5678 - loss: 1.1746 - val_acc: 0.5775 - val_loss: 1.1670 Epoch 6/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.6096 - loss: 1.1071 - val_acc: 0.6034 - val_loss: 1.1400 Epoch 7/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.6242 - loss: 1.0413 - val_acc: 0.6235 - val_loss: 1.0756 Epoch 8/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.6284 - loss: 1.0264 - val_acc: 0.6030 - val_loss: 
1.1048 Epoch 9/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.6491 - loss: 0.9706 - val_acc: 0.5770 - val_loss: 1.2818 Epoch 10/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.6754 - loss: 0.9104 - val_acc: 0.6119 - val_loss: 1.1087 Epoch 11/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 20ms/step - acc: 0.6620 - loss: 0.8855 - val_acc: 0.6323 - val_loss: 1.0526 Epoch 12/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 19ms/step - acc: 0.7060 - loss: 0.8179 - val_acc: 0.6406 - val_loss: 1.0565 Epoch 13/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - acc: 0.7252 - loss: 0.7796 - val_acc: 0.6135 - val_loss: 1.1273 Epoch 14/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.7176 - loss: 0.7935 - val_acc: 0.6292 - val_loss: 1.1028 Epoch 15/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.7322 - loss: 0.7471 - val_acc: 0.6266 - val_loss: 1.1313 Epoch 16/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.7400 - loss: 0.7218 - val_acc: 0.6332 - val_loss: 1.1064 Epoch 17/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.7490 - loss: 0.6968 - val_acc: 0.6532 - val_loss: 1.0112 Epoch 18/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.7491 - loss: 0.6879 - val_acc: 0.6403 - val_loss: 1.1083 Epoch 19/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 4s 17ms/step - acc: 0.7802 - loss: 0.6504 - val_acc: 0.6479 - val_loss: 1.0548 Epoch 20/20 200/200 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - acc: 0.7800 - loss: 0.6234 - val_acc: 0.6409 - val_loss: 1.0998 Maximal validation accuracy: 65.32% </code></pre></div> </div> <hr /> <h2 id="comparison-against-the-baseline">Comparison against the baseline</h2> <div class="codehilite"><pre><span></span><code><span class="c1"># The classification accuracies of the baseline and the pretraining + finetuning process:</span> <span class="k">def</span> <span class="nf">plot_training_curves</span><span class="p">(</span><span class="n">pretraining_history</span><span class="p">,</span> <span class="n">finetuning_history</span><span class="p">,</span> <span class="n">baseline_history</span><span class="p">):</span> <span class="k">for</span> <span class="n">metric_key</span><span class="p">,</span> <span class="n">metric_name</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">([</span><span class="s2">"acc"</span><span class="p">,</span> <span class="s2">"loss"</span><span class="p">],</span> <span class="p">[</span><span class="s2">"accuracy"</span><span class="p">,</span> <span class="s2">"loss"</span><span class="p">]):</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="mi">5</span><span class="p">),</span> <span class="n">dpi</span><span class="o">=</span><span class="mi">100</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span> <span class="n">baseline_history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="sa">f</span><span class="s2">"val_</span><span class="si">{</span><span class="n">metric_key</span><span class="si">}</span><span class="s2">"</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s2">"supervised baseline"</span><span class="p">,</span> <span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span> <span 
class="n">pretraining_history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="sa">f</span><span class="s2">"val_p_</span><span class="si">{</span><span class="n">metric_key</span><span class="si">}</span><span class="s2">"</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s2">"self-supervised pretraining"</span><span class="p">,</span> <span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span> <span class="n">finetuning_history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="sa">f</span><span class="s2">"val_</span><span class="si">{</span><span class="n">metric_key</span><span class="si">}</span><span class="s2">"</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s2">"supervised finetuning"</span><span class="p">,</span> <span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Classification </span><span class="si">{</span><span class="n">metric_name</span><span class="si">}</span><span class="s2"> during training"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s2">"epochs"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="sa">f</span><span class="s2">"validation </span><span class="si">{</span><span class="n">metric_name</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="n">plot_training_curves</span><span class="p">(</span><span class="n">pretraining_history</span><span class="p">,</span> <span class="n">finetuning_history</span><span class="p">,</span> <span class="n">baseline_history</span><span class="p">)</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/semisupervised_simclr/semisupervised_simclr_19_0.png" /></p> <p><img alt="png" src="/img/examples/vision/semisupervised_simclr/semisupervised_simclr_19_1.png" /></p> <p>By comparing the training curves, we can see that when using contrastive pretraining, a higher validation accuracy can be reached, paired with a lower validation loss, which means that the pretrained network was able to generalize better when seeing only a small amount of labeled examples.</p> <hr /> <h2 id="improving-further">Improving further</h2> <h3 id="architecture">Architecture</h3> <p>The experiment in the original paper demonstrated that increasing the width and depth of the models improves performance at a higher rate than for supervised learning. Also, using a <a href="https://keras.io/api/applications/resnet/#resnet50-function">ResNet-50</a> encoder is quite standard in the literature. 
<p>However, keep in mind that more powerful models will not only increase training time but will also require more memory and will limit the maximal batch size you can use.</p> <p>It has <a href="https://arxiv.org/abs/1905.09272">been</a> <a href="https://arxiv.org/abs/1911.05722">reported</a> that the usage of BatchNorm layers could sometimes degrade performance, as it introduces an intra-batch dependency between samples, which is why I did not use them in this example. In my experiments, however, using BatchNorm, especially in the projection head, improves performance.</p> <h3 id="hyperparameters">Hyperparameters</h3> <p>The hyperparameters used in this example have been tuned manually for this task and architecture. Therefore, without changing them, only marginal gains can be expected from further hyperparameter tuning.</p> <p>However, for a different task or model architecture these would need tuning, so here are my notes on the most important ones:</p> <ul> <li><strong>Batch size</strong>: since the objective can be interpreted as a classification over a batch of images (loosely speaking), the batch size is actually a more important hyperparameter than usual. The higher, the better.</li> <li><strong>Temperature</strong>: the temperature defines the "softness" of the softmax distribution that is used in the cross-entropy loss, and is an important hyperparameter. Lower values generally lead to a higher contrastive accuracy. A recent trick (in <a href="https://arxiv.org/abs/2102.05918">ALIGN</a>) is to learn the temperature's value as well (which can be done by defining it as a tf.Variable and applying gradients to it). Even though this provides a good baseline value, in my experiments the learned temperature was somewhat lower than optimal, as it is optimized with respect to the contrastive loss, which is not a perfect proxy for representation quality.</li> <li><strong>Image augmentation strength</strong>: during pretraining, stronger augmentations increase the difficulty of the task; however, past a point, overly strong augmentations will degrade performance. During finetuning, stronger augmentations reduce overfitting, while in my experience overly strong augmentations decrease the performance gains from pretraining. The whole data augmentation pipeline can be seen as an important hyperparameter of the algorithm; implementations of other custom image augmentation layers in Keras can be found in <a href="https://github.com/beresandras/image-augmentation-layers-keras">this repository</a>.</li> <li><strong>Learning rate schedule</strong>: a constant schedule is used here, but it is quite common in the literature to use a <a href="https://www.tensorflow.org/api_docs/python/tf/keras/experimental/CosineDecay">cosine decay schedule</a>, which can further improve performance (see the sketch below).</li> <li><strong>Optimizer</strong>: Adam is used in this example, as it provides good performance with default parameters.
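<p>For instance, a cosine decay schedule could be plugged into the two optimizers roughly as follows. This is a hedged sketch, not part of the example; the step count is an assumption based on the 200 steps per epoch seen in the logs above, and the initial learning rate is simply Adam's default:</p> <div class="codehilite"><pre><code># Hypothetical cosine-decayed optimizers (sketch only; the example uses plain Adam)
total_steps = 200 * num_epochs  # assumes 200 training steps per epoch, as in the logs above
lr_schedule = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=1e-3,  # Adam's default learning rate, used as the starting point
    decay_steps=total_steps,
)
pretraining_model.compile(
    contrastive_optimizer=keras.optimizers.Adam(lr_schedule),
    probe_optimizer=keras.optimizers.Adam(lr_schedule),
)
</code></pre></div>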
SGD with momentum requires more tuning, however it could slightly increase performance.</li> </ul> <hr /> <h2 id="related-works">Related works</h2> <p>Other instance-level (image-level) contrastive learning methods:</p> <ul> <li><a href="https://arxiv.org/abs/1911.05722">MoCo</a> (<a href="https://arxiv.org/abs/2003.04297">v2</a>, <a href="https://arxiv.org/abs/2104.02057">v3</a>): uses a momentum-encoder as well, whose weights are an exponential moving average of the target encoder</li> <li><a href="https://arxiv.org/abs/2006.09882">SwAV</a>: uses clustering instead of pairwise comparison</li> <li><a href="https://arxiv.org/abs/2103.03230">BarlowTwins</a>: uses a cross correlation-based objective instead of pairwise comparison</li> </ul> <p>Keras implementations of <strong>MoCo</strong> and <strong>BarlowTwins</strong> can be found in <a href="https://github.com/beresandras/contrastive-classification-keras">this repository</a>, which includes a Colab notebook.</p> <p>There is also a new line of works, which optimize a similar objective, but without the use of any negatives:</p> <ul> <li><a href="https://arxiv.org/abs/2006.07733">BYOL</a>: momentum-encoder + no negatives</li> <li><a href="https://arxiv.org/abs/2011.10566">SimSiam</a> (<a href="https://keras.io/examples/vision/simsiam/">Keras example</a>): no momentum-encoder + no negatives</li> </ul> <p>In my experience, these methods are more brittle (they can collapse to a constant representation, I could not get them to work using this encoder architecture). Even though they are generally more dependent on the <a href="https://generallyintelligent.ai/understanding-self-supervised-contrastive-learning.html">model</a> <a href="https://arxiv.org/abs/2010.10241">architecture</a>, they can improve performance at smaller batch sizes.</p> <p>You can use the trained model hosted on <a href="https://huggingface.co/keras-io/semi-supervised-classification-simclr">Hugging Face Hub</a> and try the demo on <a href="https://huggingface.co/spaces/keras-io/semi-supervised-classification">Hugging Face Spaces</a>.</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#semisupervised-image-classification-using-contrastive-pretraining-with-simclr'>Semi-supervised image classification using contrastive pretraining with SimCLR</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-3'> <a href='#semisupervised-learning'>Semi-supervised learning</a> </div> <div class='k-outline-depth-3'> <a href='#contrastive-learning'>Contrastive learning</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#hyperparameter-setup'>Hyperparameter setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#dataset'>Dataset</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#image-augmentations'>Image augmentations</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#encoder-architecture'>Encoder architecture</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#supervised-baseline-model'>Supervised baseline model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#selfsupervised-model-for-contrastive-pretraining'>Self-supervised model for contrastive pretraining</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#supervised-finetuning-of-the-pretrained-encoder'>Supervised finetuning of the pretrained encoder</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#comparison-against-the-baseline'>Comparison against 
the baseline</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#improving-further'>Improving further</a> </div> <div class='k-outline-depth-3'> <a href='#architecture'>Architecture</a> </div> <div class='k-outline-depth-3'> <a href='#hyperparameters'>Hyperparameters</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#related-works'>Related works</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>