CINXE.COM
Self-supervised contrastive learning with SimSiam
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/simsiam/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Self-supervised contrastive learning with SimSiam"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Self-supervised contrastive learning with SimSiam"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Self-supervised contrastive learning with SimSiam</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a class="nav-sublink2" href="/examples/vision/mobilevit/">A mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using Composable Fully-Convolutional Networks</a> <a class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved Robustness</a> <a class="nav-sublink2" href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/bit/">Image Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2 active" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a class="nav-sublink2" href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/vision/'>Computer Vision</a> / Self-supervised contrastive learning with SimSiam </div> <div class='k-content'> <h1 id="selfsupervised-contrastive-learning-with-simsiam">Self-supervised contrastive learning with SimSiam</h1> <p><strong>Author:</strong> <a href="https://twitter.com/RisingSayak">Sayak Paul</a><br> <strong>Date created:</strong> 2021/03/19<br> <strong>Last modified:</strong> 2023/12/29<br> <strong>Description:</strong> Implementation of a self-supervised learning method for computer vision.</p> <div class='example_version_banner keras_2'>ⓘ This example uses Keras 2</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/simsiam.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/simsiam.py"><strong>GitHub source</strong></a></p> <p>Self-supervised learning (SSL) is an interesting branch of study in the field of representation learning. SSL systems try to formulate a supervised signal from a corpus of unlabeled data points. An example is we train a deep neural network to predict the next word from a given set of words. In literature, these tasks are known as <em>pretext tasks</em> or <em>auxiliary tasks</em>. If we <a href="https://arxiv.org/abs/1801.06146">train such a network</a> on a huge dataset (such as the <a href="https://www.corpusdata.org/wikipedia.asp">Wikipedia text corpus</a>) it learns very effective representations that transfer well to downstream tasks. Language models like <a href="https://arxiv.org/abs/1810.04805">BERT</a>, <a href="https://arxiv.org/abs/2005.14165">GPT-3</a>, <a href="https://allennlp.org/elmo">ELMo</a> all benefit from this.</p> <p>Much like the language models we can train computer vision models using similar approaches. To make things work in computer vision, we need to formulate the learning tasks such that the underlying model (a deep neural network) is able to make sense of the semantic information present in vision data. One such task is to a model to <em>contrast</em> between two different versions of the same image. The hope is that in this way the model will have learn representations where the similar images are grouped as together possible while the dissimilar images are further away.</p> <p>In this example, we will be implementing one such system called <strong>SimSiam</strong> proposed in <a href="https://arxiv.org/abs/2011.10566">Exploring Simple Siamese Representation Learning</a>. It is implemented as the following:</p> <ol> <li>We create two different versions of the same dataset with a stochastic data augmentation pipeline. Note that the random initialization seed needs to be the same during create these versions.</li> <li>We take a ResNet without any classification head (<strong>backbone</strong>) and we add a shallow fully-connected network (<strong>projection head</strong>) on top of it. Collectively, this is known as the <strong>encoder</strong>.</li> <li>We pass the output of the encoder through a <strong>predictor</strong> which is again a shallow fully-connected network having an <a href="https://en.wikipedia.org/wiki/Autoencoder">AutoEncoder</a> like structure.</li> <li>We then train our encoder to maximize the cosine similarity between the two different versions of our dataset.</li> </ol> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"tensorflow"</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="nn">keras_cv</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">ops</span> <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> </code></pre></div> <hr /> <h2 id="define-hyperparameters">Define hyperparameters</h2> <div class="codehilite"><pre><span></span><code><span class="n">AUTO</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span> <span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">128</span> <span class="n">EPOCHS</span> <span class="o">=</span> <span class="mi">5</span> <span class="n">CROP_TO</span> <span class="o">=</span> <span class="mi">32</span> <span class="n">SEED</span> <span class="o">=</span> <span class="mi">26</span> <span class="n">PROJECT_DIM</span> <span class="o">=</span> <span class="mi">2048</span> <span class="n">LATENT_DIM</span> <span class="o">=</span> <span class="mi">512</span> <span class="n">WEIGHT_DECAY</span> <span class="o">=</span> <span class="mf">0.0005</span> </code></pre></div> <hr /> <h2 id="load-the-cifar10-dataset">Load the CIFAR-10 dataset</h2> <div class="codehilite"><pre><span></span><code><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">cifar10</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Total training examples: </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">x_train</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Total test examples: </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">x_test</span><span class="p">)</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Total training examples: 50000 Total test examples: 10000 </code></pre></div> </div> <hr /> <h2 id="defining-our-data-augmentation-pipeline">Defining our data augmentation pipeline</h2> <p>As studied in <a href="https://arxiv.org/abs/2002.05709">SimCLR</a> having the right data augmentation pipeline is critical for SSL systems to work effectively in computer vision. Two particular augmentation transforms that seem to matter the most are: 1.) Random resized crops and 2.) Color distortions. Most of the other SSL systems for computer vision (such as <a href="https://arxiv.org/abs/2006.07733">BYOL</a>, <a href="https://arxiv.org/abs/2003.04297">MoCoV2</a>, <a href="https://arxiv.org/abs/2006.09882">SwAV</a>, etc.) include these in their training pipelines.</p> <div class="codehilite"><pre><span></span><code><span class="n">strength</span> <span class="o">=</span> <span class="p">[</span><span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">]</span> <span class="n">random_flip</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">RandomFlip</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s2">"horizontal_and_vertical"</span><span class="p">)</span> <span class="n">random_crop</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">RandomCrop</span><span class="p">(</span><span class="n">CROP_TO</span><span class="p">,</span> <span class="n">CROP_TO</span><span class="p">)</span> <span class="n">random_brightness</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">RandomBrightness</span><span class="p">(</span><span class="mf">0.8</span> <span class="o">*</span> <span class="n">strength</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="n">random_contrast</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">RandomContrast</span><span class="p">((</span><span class="mi">1</span> <span class="o">-</span> <span class="mf">0.8</span> <span class="o">*</span> <span class="n">strength</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="mi">1</span> <span class="o">+</span> <span class="mf">0.8</span> <span class="o">*</span> <span class="n">strength</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span> <span class="n">random_saturation</span> <span class="o">=</span> <span class="n">keras_cv</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">RandomSaturation</span><span class="p">(</span> <span class="p">(</span><span class="mf">0.5</span> <span class="o">-</span> <span class="mf">0.8</span> <span class="o">*</span> <span class="n">strength</span><span class="p">[</span><span class="mi">2</span><span class="p">],</span> <span class="mf">0.5</span> <span class="o">+</span> <span class="mf">0.8</span> <span class="o">*</span> <span class="n">strength</span><span class="p">[</span><span class="mi">2</span><span class="p">])</span> <span class="p">)</span> <span class="n">random_hue</span> <span class="o">=</span> <span class="n">keras_cv</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">RandomHue</span><span class="p">(</span><span class="mf">0.2</span> <span class="o">*</span> <span class="n">strength</span><span class="p">[</span><span class="mi">3</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span><span class="mi">255</span><span class="p">])</span> <span class="n">grayscale</span> <span class="o">=</span> <span class="n">keras_cv</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Grayscale</span><span class="p">()</span> <span class="k">def</span> <span class="nf">flip_random_crop</span><span class="p">(</span><span class="n">image</span><span class="p">):</span> <span class="c1"># With random crops we also apply horizontal flipping.</span> <span class="n">image</span> <span class="o">=</span> <span class="n">random_flip</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="n">image</span> <span class="o">=</span> <span class="n">random_crop</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="k">return</span> <span class="n">image</span> <span class="k">def</span> <span class="nf">color_jitter</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">strength</span><span class="o">=</span><span class="p">[</span><span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">,</span> <span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">]):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">random_brightness</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">random_contrast</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">random_saturation</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">random_hue</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Affine transformations can disturb the natural range of</span> <span class="c1"># RGB images, hence this is needed.</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">255</span><span class="p">)</span> <span class="k">return</span> <span class="n">x</span> <span class="k">def</span> <span class="nf">color_drop</span><span class="p">(</span><span class="n">x</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">grayscale</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="k">return</span> <span class="n">x</span> <span class="k">def</span> <span class="nf">random_apply</span><span class="p">(</span><span class="n">func</span><span class="p">,</span> <span class="n">x</span><span class="p">,</span> <span class="n">p</span><span class="p">):</span> <span class="k">if</span> <span class="n">keras</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">([],</span> <span class="n">minval</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">maxval</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span> <span class="o"><</span> <span class="n">p</span><span class="p">:</span> <span class="k">return</span> <span class="n">func</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="k">return</span> <span class="n">x</span> <span class="k">def</span> <span class="nf">custom_augment</span><span class="p">(</span><span class="n">image</span><span class="p">):</span> <span class="c1"># As discussed in the SimCLR paper, the series of augmentation</span> <span class="c1"># transformations (except for random crops) need to be applied</span> <span class="c1"># randomly to impose translational invariance.</span> <span class="n">image</span> <span class="o">=</span> <span class="n">flip_random_crop</span><span class="p">(</span><span class="n">image</span><span class="p">)</span> <span class="n">image</span> <span class="o">=</span> <span class="n">random_apply</span><span class="p">(</span><span class="n">color_jitter</span><span class="p">,</span> <span class="n">image</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mf">0.8</span><span class="p">)</span> <span class="n">image</span> <span class="o">=</span> <span class="n">random_apply</span><span class="p">(</span><span class="n">color_drop</span><span class="p">,</span> <span class="n">image</span><span class="p">,</span> <span class="n">p</span><span class="o">=</span><span class="mf">0.2</span><span class="p">)</span> <span class="k">return</span> <span class="n">image</span> </code></pre></div> <p>It should be noted that an augmentation pipeline is generally dependent on various properties of the dataset we are dealing with. For example, if images in the dataset are heavily object-centric then taking random crops with a very high probability may hurt the training performance.</p> <p>Let's now apply our augmentation pipeline to our dataset and visualize a few outputs.</p> <hr /> <h2 id="convert-the-data-into-tensorflow-dataset-objects">Convert the data into TensorFlow <code>Dataset</code> objects</h2> <p>Here we create two different versions of our dataset <em>without</em> any ground-truth labels.</p> <div class="codehilite"><pre><span></span><code><span class="n">ssl_ds_one</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">(</span><span class="n">x_train</span><span class="p">)</span> <span class="n">ssl_ds_one</span> <span class="o">=</span> <span class="p">(</span> <span class="n">ssl_ds_one</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="mi">1024</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="n">SEED</span><span class="p">)</span> <span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">custom_augment</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTO</span><span class="p">)</span> <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span> <span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">AUTO</span><span class="p">)</span> <span class="p">)</span> <span class="n">ssl_ds_two</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">(</span><span class="n">x_train</span><span class="p">)</span> <span class="n">ssl_ds_two</span> <span class="o">=</span> <span class="p">(</span> <span class="n">ssl_ds_two</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="mi">1024</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="n">SEED</span><span class="p">)</span> <span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">custom_augment</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTO</span><span class="p">)</span> <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span> <span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">AUTO</span><span class="p">)</span> <span class="p">)</span> <span class="c1"># We then zip both of these datasets.</span> <span class="n">ssl_ds</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">zip</span><span class="p">((</span><span class="n">ssl_ds_one</span><span class="p">,</span> <span class="n">ssl_ds_two</span><span class="p">))</span> <span class="c1"># Visualize a few augmented images.</span> <span class="n">sample_images_one</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">ssl_ds_one</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">25</span><span class="p">):</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="n">n</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">sample_images_one</span><span class="p">[</span><span class="n">n</span><span class="p">]</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"int"</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="c1"># Ensure that the different versions of the dataset actually contain</span> <span class="c1"># identical images.</span> <span class="n">sample_images_two</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">ssl_ds_two</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">25</span><span class="p">):</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="n">n</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">sample_images_two</span><span class="p">[</span><span class="n">n</span><span class="p">]</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"int"</span><span class="p">))</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> </code></pre></div> <p><img alt="png" src="/img/examples/vision/simsiam/simsiam_12_0.png" /></p> <p><img alt="png" src="/img/examples/vision/simsiam/simsiam_12_1.png" /></p> <p>Notice that the images in <code>samples_images_one</code> and <code>sample_images_two</code> are essentially the same but are augmented differently.</p> <hr /> <h2 id="defining-the-encoder-and-the-predictor">Defining the encoder and the predictor</h2> <p>We use an implementation of ResNet20 that is specifically configured for the CIFAR10 dataset. The code is taken from the <a href="https://github.com/GoogleCloudPlatform/keras-idiomatic-programmer/blob/master/zoo/resnet/resnet_cifar10_v2.py">keras-idiomatic-programmer</a> repository. The hyperparameters of these architectures have been referred from Section 3 and Appendix A of <a href="https://arxiv.org/abs/2011.10566">the original paper</a>.</p> <div class="codehilite"><pre><span></span><code><span class="err">!</span><span class="n">wget</span> <span class="o">-</span><span class="n">q</span> <span class="n">https</span><span class="p">:</span><span class="o">//</span><span class="n">git</span><span class="o">.</span><span class="n">io</span><span class="o">/</span><span class="n">JYx2x</span> <span class="o">-</span><span class="n">O</span> <span class="n">resnet_cifar10_v2</span><span class="o">.</span><span class="n">py</span> </code></pre></div> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">resnet_cifar10_v2</span> <span class="n">N</span> <span class="o">=</span> <span class="mi">2</span> <span class="n">DEPTH</span> <span class="o">=</span> <span class="n">N</span> <span class="o">*</span> <span class="mi">9</span> <span class="o">+</span> <span class="mi">2</span> <span class="n">NUM_BLOCKS</span> <span class="o">=</span> <span class="p">((</span><span class="n">DEPTH</span> <span class="o">-</span> <span class="mi">2</span><span class="p">)</span> <span class="o">//</span> <span class="mi">9</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span> <span class="k">def</span> <span class="nf">get_encoder</span><span class="p">():</span> <span class="c1"># Input and backbone.</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">((</span><span class="n">CROP_TO</span><span class="p">,</span> <span class="n">CROP_TO</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Rescaling</span><span class="p">(</span><span class="n">scale</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">/</span> <span class="mf">127.5</span><span class="p">,</span> <span class="n">offset</span><span class="o">=-</span><span class="mi">1</span><span class="p">)(</span> <span class="n">inputs</span> <span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">resnet_cifar10_v2</span><span class="o">.</span><span class="n">stem</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">resnet_cifar10_v2</span><span class="o">.</span><span class="n">learner</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">NUM_BLOCKS</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">GlobalAveragePooling2D</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"backbone_pool"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Projection head.</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span> <span class="n">PROJECT_DIM</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">kernel_regularizer</span><span class="o">=</span><span class="n">regularizers</span><span class="o">.</span><span class="n">l2</span><span class="p">(</span><span class="n">WEIGHT_DECAY</span><span class="p">)</span> <span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span> <span class="n">PROJECT_DIM</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">kernel_regularizer</span><span class="o">=</span><span class="n">regularizers</span><span class="o">.</span><span class="n">l2</span><span class="p">(</span><span class="n">WEIGHT_DECAY</span><span class="p">)</span> <span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"encoder"</span><span class="p">)</span> <span class="k">def</span> <span class="nf">get_predictor</span><span class="p">():</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="c1"># Note the AutoEncoder-like structure.</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">((</span><span class="n">PROJECT_DIM</span><span class="p">,)),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span> <span class="n">LATENT_DIM</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">kernel_regularizer</span><span class="o">=</span><span class="n">regularizers</span><span class="o">.</span><span class="n">l2</span><span class="p">(</span><span class="n">WEIGHT_DECAY</span><span class="p">),</span> <span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">ReLU</span><span class="p">(),</span> <span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">(),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">PROJECT_DIM</span><span class="p">),</span> <span class="p">],</span> <span class="n">name</span><span class="o">=</span><span class="s2">"predictor"</span><span class="p">,</span> <span class="p">)</span> <span class="k">return</span> <span class="n">model</span> </code></pre></div> <hr /> <h2 id="defining-the-pretraining-loop">Defining the (pre-)training loop</h2> <p>One of the main reasons behind training networks with these kinds of approaches is to utilize the learned representations for downstream tasks like classification. This is why this particular training phase is also referred to as <em>pre-training</em>.</p> <p>We start by defining the loss function.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">compute_loss</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">z</span><span class="p">):</span> <span class="c1"># The authors of SimSiam emphasize the impact of</span> <span class="c1"># the `stop_gradient` operator in the paper as it</span> <span class="c1"># has an important role in the overall optimization.</span> <span class="n">z</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">stop_gradient</span><span class="p">(</span><span class="n">z</span><span class="p">)</span> <span class="n">p</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">normalize</span><span class="p">(</span><span class="n">p</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">order</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> <span class="n">z</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">normalize</span><span class="p">(</span><span class="n">z</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">order</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> <span class="c1"># Negative cosine similarity (minimizing this is</span> <span class="c1"># equivalent to maximizing the similarity).</span> <span class="k">return</span> <span class="o">-</span><span class="n">ops</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">sum</span><span class="p">((</span><span class="n">p</span> <span class="o">*</span> <span class="n">z</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">))</span> </code></pre></div> <p>We then define our training loop by overriding the <code>train_step()</code> function of the <a href="/api/models/model#model-class"><code>keras.Model</code></a> class.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">SimSiam</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">encoder</span><span class="p">,</span> <span class="n">predictor</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span> <span class="o">=</span> <span class="n">encoder</span> <span class="bp">self</span><span class="o">.</span><span class="n">predictor</span> <span class="o">=</span> <span class="n">predictor</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_tracker</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"loss"</span><span class="p">)</span> <span class="nd">@property</span> <span class="k">def</span> <span class="nf">metrics</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="k">return</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">loss_tracker</span><span class="p">]</span> <span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span> <span class="c1"># Unpack the data.</span> <span class="n">ds_one</span><span class="p">,</span> <span class="n">ds_two</span> <span class="o">=</span> <span class="n">data</span> <span class="c1"># Forward pass through the encoder and predictor.</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="n">z1</span><span class="p">,</span> <span class="n">z2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">ds_one</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="p">(</span><span class="n">ds_two</span><span class="p">)</span> <span class="n">p1</span><span class="p">,</span> <span class="n">p2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">predictor</span><span class="p">(</span><span class="n">z1</span><span class="p">),</span> <span class="bp">self</span><span class="o">.</span><span class="n">predictor</span><span class="p">(</span><span class="n">z2</span><span class="p">)</span> <span class="c1"># Note that here we are enforcing the network to match</span> <span class="c1"># the representations of two differently augmented batches</span> <span class="c1"># of data.</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">compute_loss</span><span class="p">(</span><span class="n">p1</span><span class="p">,</span> <span class="n">z2</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span> <span class="o">+</span> <span class="n">compute_loss</span><span class="p">(</span><span class="n">p2</span><span class="p">,</span> <span class="n">z1</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span> <span class="c1"># Compute gradients and update the parameters.</span> <span class="n">learnable_params</span> <span class="o">=</span> <span class="p">(</span> <span class="bp">self</span><span class="o">.</span><span class="n">encoder</span><span class="o">.</span><span class="n">trainable_variables</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">predictor</span><span class="o">.</span><span class="n">trainable_variables</span> <span class="p">)</span> <span class="n">gradients</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">learnable_params</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="n">learnable_params</span><span class="p">))</span> <span class="c1"># Monitor loss.</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span> <span class="k">return</span> <span class="p">{</span><span class="s2">"loss"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_tracker</span><span class="o">.</span><span class="n">result</span><span class="p">()}</span> </code></pre></div> <hr /> <h2 id="pretraining-our-networks">Pre-training our networks</h2> <p>In the interest of this example, we will train the model for only 5 epochs. In reality, this should at least be 100 epochs.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Create a cosine decay learning scheduler.</span> <span class="n">num_training_samples</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">x_train</span><span class="p">)</span> <span class="n">steps</span> <span class="o">=</span> <span class="n">EPOCHS</span> <span class="o">*</span> <span class="p">(</span><span class="n">num_training_samples</span> <span class="o">//</span> <span class="n">BATCH_SIZE</span><span class="p">)</span> <span class="n">lr_decayed_fn</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">schedules</span><span class="o">.</span><span class="n">CosineDecay</span><span class="p">(</span> <span class="n">initial_learning_rate</span><span class="o">=</span><span class="mf">0.03</span><span class="p">,</span> <span class="n">decay_steps</span><span class="o">=</span><span class="n">steps</span> <span class="p">)</span> <span class="c1"># Create an early stopping callback.</span> <span class="n">early_stopping</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">EarlyStopping</span><span class="p">(</span> <span class="n">monitor</span><span class="o">=</span><span class="s2">"loss"</span><span class="p">,</span> <span class="n">patience</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">restore_best_weights</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="c1"># Compile model and start training.</span> <span class="n">simsiam</span> <span class="o">=</span> <span class="n">SimSiam</span><span class="p">(</span><span class="n">get_encoder</span><span class="p">(),</span> <span class="n">get_predictor</span><span class="p">())</span> <span class="n">simsiam</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">SGD</span><span class="p">(</span><span class="n">lr_decayed_fn</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.6</span><span class="p">))</span> <span class="n">history</span> <span class="o">=</span> <span class="n">simsiam</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">ssl_ds</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">EPOCHS</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">early_stopping</span><span class="p">])</span> <span class="c1"># Visualize the training progress of the model.</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">"loss"</span><span class="p">])</span> <span class="n">plt</span><span class="o">.</span><span class="n">grid</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">"Negative Cosine Similairty"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/5 391/391 [==============================] - 33s 42ms/step - loss: -0.8973 Epoch 2/5 391/391 [==============================] - 16s 40ms/step - loss: -0.9129 Epoch 3/5 391/391 [==============================] - 16s 40ms/step - loss: -0.9165 Epoch 4/5 391/391 [==============================] - 16s 40ms/step - loss: -0.9176 Epoch 5/5 391/391 [==============================] - 16s 40ms/step - loss: -0.9182 </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/simsiam/simsiam_22_1.png" /></p> <p>If your solution gets very close to -1 (minimum value of our loss) very quickly with a different dataset and a different backbone architecture that is likely because of <em>representation collapse</em>. It is a phenomenon where the encoder yields similar output for all the images. In that case additional hyperparameter tuning is required especially in the following areas:</p> <ul> <li>Strength of the color distortions and their probabilities.</li> <li>Learning rate and its schedule.</li> <li>Architecture of both the backbone and their projection head.</li> </ul> <hr /> <h2 id="evaluating-our-ssl-method">Evaluating our SSL method</h2> <p>The most popularly used method to evaluate a SSL method in computer vision (or any other pre-training method as such) is to learn a linear classifier on the frozen features of the trained backbone model (in this case it is ResNet20) and evaluate the classifier on unseen images. Other methods include <a href="https://keras.io/guides/transfer_learning/">fine-tuning</a> on the source dataset or even a target dataset with 5% or 10% labels present. Practically, we can use the backbone model for any downstream task such as semantic segmentation, object detection, and so on where the backbone models are usually pre-trained with <em>pure supervised learning</em>.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># We first create labeled `Dataset` objects.</span> <span class="n">train_ds</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">))</span> <span class="n">test_ds</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">))</span> <span class="c1"># Then we shuffle, batch, and prefetch this dataset for performance. We</span> <span class="c1"># also apply random resized crops as an augmentation but only to the</span> <span class="c1"># training set.</span> <span class="n">train_ds</span> <span class="o">=</span> <span class="p">(</span> <span class="n">train_ds</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="mi">1024</span><span class="p">)</span> <span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="p">(</span><span class="n">flip_random_crop</span><span class="p">(</span><span class="n">x</span><span class="p">),</span> <span class="n">y</span><span class="p">),</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTO</span><span class="p">)</span> <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span> <span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">AUTO</span><span class="p">)</span> <span class="p">)</span> <span class="n">test_ds</span> <span class="o">=</span> <span class="n">test_ds</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span><span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">AUTO</span><span class="p">)</span> <span class="c1"># Extract the backbone ResNet20.</span> <span class="n">backbone</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span> <span class="n">simsiam</span><span class="o">.</span><span class="n">encoder</span><span class="o">.</span><span class="n">input</span><span class="p">,</span> <span class="n">simsiam</span><span class="o">.</span><span class="n">encoder</span><span class="o">.</span><span class="n">get_layer</span><span class="p">(</span><span class="s2">"backbone_pool"</span><span class="p">)</span><span class="o">.</span><span class="n">output</span> <span class="p">)</span> <span class="c1"># We then create our linear classifier and train it.</span> <span class="n">backbone</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">False</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">((</span><span class="n">CROP_TO</span><span class="p">,</span> <span class="n">CROP_TO</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="n">x</span> <span class="o">=</span> <span class="n">backbone</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"softmax"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">linear_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"linear_model"</span><span class="p">)</span> <span class="c1"># Compile model and start training.</span> <span class="n">linear_model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"sparse_categorical_crossentropy"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"accuracy"</span><span class="p">],</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">SGD</span><span class="p">(</span><span class="n">lr_decayed_fn</span><span class="p">,</span> <span class="n">momentum</span><span class="o">=</span><span class="mf">0.9</span><span class="p">),</span> <span class="p">)</span> <span class="n">history</span> <span class="o">=</span> <span class="n">linear_model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">train_ds</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">test_ds</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">EPOCHS</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">early_stopping</span><span class="p">]</span> <span class="p">)</span> <span class="n">_</span><span class="p">,</span> <span class="n">test_acc</span> <span class="o">=</span> <span class="n">linear_model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">test_ds</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Test accuracy: </span><span class="si">{:.2f}</span><span class="s2">%"</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">test_acc</span> <span class="o">*</span> <span class="mi">100</span><span class="p">))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/5 391/391 [==============================] - 7s 11ms/step - loss: 3.8072 - accuracy: 0.1527 - val_loss: 3.7449 - val_accuracy: 0.2046 Epoch 2/5 391/391 [==============================] - 3s 8ms/step - loss: 3.7356 - accuracy: 0.2107 - val_loss: 3.7055 - val_accuracy: 0.2308 Epoch 3/5 391/391 [==============================] - 3s 8ms/step - loss: 3.7036 - accuracy: 0.2228 - val_loss: 3.6874 - val_accuracy: 0.2329 Epoch 4/5 391/391 [==============================] - 3s 8ms/step - loss: 3.6893 - accuracy: 0.2276 - val_loss: 3.6808 - val_accuracy: 0.2334 Epoch 5/5 391/391 [==============================] - 3s 9ms/step - loss: 3.6845 - accuracy: 0.2305 - val_loss: 3.6798 - val_accuracy: 0.2339 79/79 [==============================] - 1s 7ms/step - loss: 3.6798 - accuracy: 0.2339 Test accuracy: 23.39% </code></pre></div> </div> <hr /> <h2 id="notes">Notes</h2> <ul> <li>More data and longer pre-training schedule benefit SSL in general.</li> <li>SSL is particularly very helpful when you do not have access to very limited <em>labeled</em> training data but you can manage to build a large corpus of unlabeled data. Recently, using an SSL method called <a href="https://arxiv.org/abs/2006.09882">SwAV</a>, a group of researchers at Facebook trained a <a href="https://arxiv.org/abs/2006.09882">RegNet</a> on 2 Billion images. They were able to achieve downstream performance very close to those achieved by pure supervised pre-training. For some downstream tasks, their method even outperformed the supervised counterparts. You can check out <a href="https://arxiv.org/pdf/2103.01988.pdf">their paper</a> to know the details.</li> <li>If you are interested to understand why contrastive SSL helps networks learn meaningful representations, you can check out the following resources:<ul> <li><a href="https://ai.facebook.com/blog/self-supervised-learning-the-dark-matter-of-intelligence/">Self-supervised learning: The dark matter of intelligence</a></li> <li><a href="https://sslneuips20.github.io/files/CameraReadys%203-77/64/CameraReady/Understanding_self_supervised_learning.pdf">Understanding self-supervised learning using controlled datasets with known structure</a></li> </ul> </li> </ul> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#selfsupervised-contrastive-learning-with-simsiam'>Self-supervised contrastive learning with SimSiam</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#define-hyperparameters'>Define hyperparameters</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#load-the-cifar10-dataset'>Load the CIFAR-10 dataset</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#defining-our-data-augmentation-pipeline'>Defining our data augmentation pipeline</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#convert-the-data-into-tensorflow-dataset-objects'>Convert the data into TensorFlow <code>Dataset</code> objects</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#defining-the-encoder-and-the-predictor'>Defining the encoder and the predictor</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#defining-the-pretraining-loop'>Defining the (pre-)training loop</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#pretraining-our-networks'>Pre-training our networks</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#evaluating-our-ssl-method'>Evaluating our SSL method</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#notes'>Notes</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>