<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/mlp_image_classification/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Image classification with modern MLP models"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Image classification with modern MLP models"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Image classification with modern MLP models</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); 
ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2 active" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a class="nav-sublink2" href="/examples/vision/mobilevit/">A mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" 
href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using Composable Fully-Convolutional 
Networks</a> <a class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved Robustness</a> <a 
class="nav-sublink2" href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a class="nav-sublink2" 
href="/examples/vision/bit/">Image Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a class="nav-sublink2" 
href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } 
window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + encodeURIComponent(query); return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/vision/'>Computer
Vision</a> / Image classification with modern MLP models </div> <div class='k-content'> <h1 id="image-classification-with-modern-mlp-models">Image classification with modern MLP models</h1> <p><strong>Author:</strong> <a href="https://www.linkedin.com/in/khalid-salama-24403144/">Khalid Salama</a><br> <strong>Date created:</strong> 2021/05/30<br> <strong>Last modified:</strong> 2023/08/03<br> <strong>Description:</strong> Implementing the MLP-Mixer, FNet, and gMLP models for CIFAR-100 image classification.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/mlp_image_classification.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/mlp_image_classification.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>This example implements three modern attention-free, multi-layer perceptron (MLP)-based models for image classification, demonstrated on the CIFAR-100 dataset:</p> <ol> <li>The <a href="https://arxiv.org/abs/2105.01601">MLP-Mixer</a> model, by Ilya Tolstikhin et al., based on two types of MLPs.</li> <li>The <a href="https://arxiv.org/abs/2105.03824">FNet</a> model, by James Lee-Thorp et al., based on the unparameterized Fourier Transform.</li> <li>The <a href="https://arxiv.org/abs/2105.08050">gMLP</a> model, by Hanxiao Liu et al., based on MLPs with gating.</li> </ol> <p>The purpose of the example is not to compare these models, as they might perform differently on different datasets with well-tuned hyperparameters.
Rather, it is to show simple implementations of their main building blocks.</p> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span> <span class="kn">import</span><span class="w"> </span><span class="nn">keras</span> <span class="kn">from</span><span class="w"> </span><span class="nn">keras</span><span class="w"> </span><span class="kn">import</span> <span class="n">layers</span> </code></pre></div> <hr /> <h2 id="prepare-the-data">Prepare the data</h2> <div class="codehilite"><pre><span></span><code><span class="n">num_classes</span> <span class="o">=</span> <span class="mi">100</span> <span class="n">input_shape</span> <span class="o">=</span> <span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span> <span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">cifar100</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;x_train shape: </span><span class="si">{</span><span class="n">x_train</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2"> - y_train shape: </span><span class="si">{</span><span class="n">y_train</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span 
class="s2">&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;x_test shape: </span><span class="si">{</span><span class="n">x_test</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2"> - y_test shape: </span><span class="si">{</span><span class="n">y_test</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1) x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1) </code></pre></div> </div> <hr /> <h2 id="configure-the-hyperparameters">Configure the hyperparameters</h2> <div class="codehilite"><pre><span></span><code><span class="n">weight_decay</span> <span class="o">=</span> <span class="mf">0.0001</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">128</span> <span class="n">num_epochs</span> <span class="o">=</span> <span class="mi">1</span> <span class="c1"># Recommended num_epochs = 50</span> <span class="n">dropout_rate</span> <span class="o">=</span> <span class="mf">0.2</span> <span class="n">image_size</span> <span class="o">=</span> <span class="mi">64</span> <span class="c1"># We&#39;ll resize input images to this size.</span> <span class="n">patch_size</span> <span class="o">=</span> <span class="mi">8</span> <span class="c1"># Size of the patches to be extracted from the input images.</span> <span class="n">num_patches</span> <span class="o">=</span> <span class="p">(</span><span class="n">image_size</span> <span class="o">//</span> <span class="n">patch_size</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span> <span class="c1"># Size of the data array.</span> <span class="n">embedding_dim</span> <span 
class="o">=</span> <span class="mi">256</span> <span class="c1"># Number of hidden units.</span> <span class="n">num_blocks</span> <span class="o">=</span> <span class="mi">4</span> <span class="c1"># Number of blocks.</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Image size: </span><span class="si">{</span><span class="n">image_size</span><span class="si">}</span><span class="s2"> X </span><span class="si">{</span><span class="n">image_size</span><span class="si">}</span><span class="s2"> = </span><span class="si">{</span><span class="n">image_size</span><span class="w"> </span><span class="o">**</span><span class="w"> </span><span class="mi">2</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Patch size: </span><span class="si">{</span><span class="n">patch_size</span><span class="si">}</span><span class="s2"> X </span><span class="si">{</span><span class="n">patch_size</span><span class="si">}</span><span class="s2"> = </span><span class="si">{</span><span class="n">patch_size</span><span class="w"> </span><span class="o">**</span><span class="w"> </span><span class="mi">2</span><span class="si">}</span><span class="s2"> &quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Patches per image: </span><span class="si">{</span><span class="n">num_patches</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Elements per patch (3 channels): </span><span class="si">{</span><span class="p">(</span><span class="n">patch_size</span><span class="w"> </span><span class="o">**</span><span class="w"> </span><span class="mi">2</span><span class="p">)</span><span 
class="w"> </span><span class="o">*</span><span class="w"> </span><span class="mi">3</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Image size: 64 X 64 = 4096 Patch size: 8 X 8 = 64 Patches per image: 64 Elements per patch (3 channels): 192 </code></pre></div> </div> <hr /> <h2 id="build-a-classification-model">Build a classification model</h2> <p>We implement a method that builds a classifier given the processing blocks.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">build_classifier</span><span class="p">(</span><span class="n">blocks</span><span class="p">,</span> <span class="n">positional_encoding</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">input_shape</span><span class="p">)</span> <span class="c1"># Augment data.</span> <span class="n">augmented</span> <span class="o">=</span> <span class="n">data_augmentation</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="c1"># Create patches.</span> <span class="n">patches</span> <span class="o">=</span> <span class="n">Patches</span><span class="p">(</span><span class="n">patch_size</span><span class="p">)(</span><span class="n">augmented</span><span class="p">)</span> <span class="c1"># Encode patches to generate a [batch_size, num_patches, embedding_dim] tensor.</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span 
class="n">embedding_dim</span><span class="p">)(</span><span class="n">patches</span><span class="p">)</span> <span class="k">if</span> <span class="n">positional_encoding</span><span class="p">:</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">PositionEmbedding</span><span class="p">(</span><span class="n">sequence_length</span><span class="o">=</span><span class="n">num_patches</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Process x using the module blocks.</span> <span class="n">x</span> <span class="o">=</span> <span class="n">blocks</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Apply global average pooling to generate a [batch_size, embedding_dim] representation tensor.</span> <span class="n">representation</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">GlobalAveragePooling1D</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Apply dropout.</span> <span class="n">representation</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">rate</span><span class="o">=</span><span class="n">dropout_rate</span><span class="p">)(</span><span class="n">representation</span><span class="p">)</span> <span class="c1"># Compute logits outputs.</span> <span class="n">logits</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">num_classes</span><span class="p">)(</span><span class="n">representation</span><span class="p">)</span> <span class="c1"># Create the Keras model.</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span 
class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">logits</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="define-an-experiment">Define an experiment</h2> <p>We implement a utility function to compile, train, and evaluate a given model.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">run_experiment</span><span class="p">(</span><span class="n">model</span><span class="p">):</span> <span class="c1"># Create Adam optimizer with weight decay.</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">AdamW</span><span class="p">(</span> <span class="n">learning_rate</span><span class="o">=</span><span class="n">learning_rate</span><span class="p">,</span> <span class="n">weight_decay</span><span class="o">=</span><span class="n">weight_decay</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Compile the model.</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(</span><span class="n">from_logits</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span 
class="n">SparseCategoricalAccuracy</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;acc&quot;</span><span class="p">),</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">SparseTopKCategoricalAccuracy</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;top5-acc&quot;</span><span class="p">),</span> <span class="p">],</span> <span class="p">)</span> <span class="c1"># Create a learning rate scheduler callback.</span> <span class="n">reduce_lr</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">ReduceLROnPlateau</span><span class="p">(</span> <span class="n">monitor</span><span class="o">=</span><span class="s2">&quot;val_loss&quot;</span><span class="p">,</span> <span class="n">factor</span><span class="o">=</span><span class="mf">0.5</span><span class="p">,</span> <span class="n">patience</span><span class="o">=</span><span class="mi">5</span> <span class="p">)</span> <span class="c1"># Create an early stopping callback.</span> <span class="n">early_stopping</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">EarlyStopping</span><span class="p">(</span> <span class="n">monitor</span><span class="o">=</span><span class="s2">&quot;val_loss&quot;</span><span class="p">,</span> <span class="n">patience</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">restore_best_weights</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="c1"># Fit the model.</span> <span class="n">history</span> <span class="o">=</span> <span class="n">model</span><span 
class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">x</span><span class="o">=</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y</span><span class="o">=</span><span class="n">y_train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">num_epochs</span><span class="p">,</span> <span class="n">validation_split</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">early_stopping</span><span class="p">,</span> <span class="n">reduce_lr</span><span class="p">],</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="p">)</span> <span class="n">_</span><span class="p">,</span> <span class="n">accuracy</span><span class="p">,</span> <span class="n">top_5_accuracy</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Test accuracy: </span><span class="si">{</span><span class="nb">round</span><span class="p">(</span><span class="n">accuracy</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="mi">100</span><span class="p">,</span><span class="w"> </span><span class="mi">2</span><span class="p">)</span><span class="si">}</span><span class="s2">%&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Test top 5 accuracy: </span><span class="si">{</span><span 
class="nb">round</span><span class="p">(</span><span class="n">top_5_accuracy</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="mi">100</span><span class="p">,</span><span class="w"> </span><span class="mi">2</span><span class="p">)</span><span class="si">}</span><span class="s2">%&quot;</span><span class="p">)</span> <span class="c1"># Return history to plot learning curves.</span> <span class="k">return</span> <span class="n">history</span> </code></pre></div> <hr /> <h2 id="use-data-augmentation">Use data augmentation</h2> <div class="codehilite"><pre><span></span><code><span class="n">data_augmentation</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">layers</span><span class="o">.</span><span class="n">Normalization</span><span class="p">(),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Resizing</span><span class="p">(</span><span class="n">image_size</span><span class="p">,</span> <span class="n">image_size</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">RandomFlip</span><span class="p">(</span><span class="s2">&quot;horizontal&quot;</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">RandomZoom</span><span class="p">(</span><span class="n">height_factor</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">width_factor</span><span class="o">=</span><span class="mf">0.2</span><span class="p">),</span> <span class="p">],</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;data_augmentation&quot;</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Compute the mean and the variance of the training data for normalization.</span> <span 
class="n">data_augmentation</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">adapt</span><span class="p">(</span><span class="n">x_train</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="implement-patch-extraction-as-a-layer">Implement patch extraction as a layer</h2> <div class="codehilite"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">Patches</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">patch_size</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span> <span class="o">=</span> <span class="n">patch_size</span> <span class="k">def</span><span class="w"> </span><span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="n">patches</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">extract_patches</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">patch_size</span><span class="p">)</span> <span class="n">batch_size</span> 
<span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">patches</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="n">num_patches</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">patches</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">patches</span><span class="p">)[</span><span class="mi">2</span><span class="p">]</span> <span class="n">patch_dim</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">patches</span><span class="p">)[</span><span class="mi">3</span><span class="p">]</span> <span class="n">out</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">patches</span><span class="p">,</span> <span class="p">(</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">num_patches</span><span class="p">,</span> <span class="n">patch_dim</span><span class="p">))</span> <span class="k">return</span> <span class="n">out</span> </code></pre></div> <hr /> <h2 id="implement-position-embedding-as-a-layer">Implement position embedding as a layer</h2> <div class="codehilite"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">PositionEmbedding</span><span 
class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">sequence_length</span><span class="p">,</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">&quot;glorot_uniform&quot;</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span> <span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="k">if</span> <span class="n">sequence_length</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="s2">&quot;`sequence_length` must be an Integer, received `None`.&quot;</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">sequence_length</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">initializer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">initializer</span><span class="p">)</span> <span class="k">def</span><span class="w"> </span><span class="nf">get_config</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="n">config</span> <span class="o">=</span> <span 
class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">get_config</span><span class="p">()</span> <span class="n">config</span><span class="o">.</span><span class="n">update</span><span class="p">(</span> <span class="p">{</span> <span class="s2">&quot;sequence_length&quot;</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">sequence_length</span><span class="p">,</span> <span class="s2">&quot;initializer&quot;</span><span class="p">:</span> <span class="n">keras</span><span class="o">.</span><span class="n">initializers</span><span class="o">.</span><span class="n">serialize</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">initializer</span><span class="p">),</span> <span class="p">}</span> <span class="p">)</span> <span class="k">return</span> <span class="n">config</span> <span class="k">def</span><span class="w"> </span><span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="n">feature_size</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="bp">self</span><span class="o">.</span><span class="n">position_embeddings</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;embeddings&quot;</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">sequence_length</span><span class="p">,</span> <span class="n">feature_size</span><span class="p">],</span> <span class="n">initializer</span><span class="o">=</span><span 
class="bp">self</span><span class="o">.</span><span class="n">initializer</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span> <span class="k">def</span><span class="w"> </span><span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">start_index</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span> <span class="n">shape</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">feature_length</span> <span class="o">=</span> <span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="n">sequence_length</span> <span class="o">=</span> <span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">2</span><span class="p">]</span> <span class="c1"># trim to match the length of the input sequence, which might be less</span> <span class="c1"># than the sequence_length of the layer.</span> <span class="n">position_embeddings</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">convert_to_tensor</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">position_embeddings</span><span class="p">)</span> <span class="n">position_embeddings</span> <span class="o">=</span> <span 
class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">slice</span><span class="p">(</span> <span class="n">position_embeddings</span><span class="p">,</span> <span class="p">(</span><span class="n">start_index</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="p">(</span><span class="n">sequence_length</span><span class="p">,</span> <span class="n">feature_length</span><span class="p">),</span> <span class="p">)</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">position_embeddings</span><span class="p">,</span> <span class="n">shape</span><span class="p">)</span> <span class="k">def</span><span class="w"> </span><span class="nf">compute_output_shape</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="k">return</span> <span class="n">input_shape</span> </code></pre></div> <hr /> <h2 id="the-mlpmixer-model">The MLP-Mixer model</h2> <p>The MLP-Mixer is an architecture based exclusively on multi-layer perceptrons (MLPs) that contains two types of MLP layers:</p> <ol> <li>One applied independently to image patches, which mixes the per-location features.</li> <li>The other applied across patches (independently per channel), which mixes spatial information.</li> </ol> <p>This is similar to a <a href="https://arxiv.org/abs/1610.02357">depthwise separable convolution-based model</a> such as the Xception model, but with two chained dense transforms, no max pooling, and layer normalization instead of batch normalization.</p> <h3 id="implement-the-mlpmixer-module">Implement the MLP-Mixer module</h3> <div class="codehilite"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span
class="nc">MLPMixerLayer</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_patches</span><span class="p">,</span> <span class="n">hidden_units</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">mlp1</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="n">num_patches</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;gelu&quot;</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="n">num_patches</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">rate</span><span class="o">=</span><span class="n">dropout_rate</span><span class="p">),</span> <span class="p">]</span> <span 
class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">mlp2</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="n">num_patches</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;gelu&quot;</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="n">hidden_units</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">rate</span><span class="o">=</span><span class="n">dropout_rate</span><span class="p">),</span> <span class="p">]</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">normalize</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span><span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">)</span> <span class="k">def</span><span class="w"> </span><span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="k">return</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">input_shape</span><span class="p">)</span> <span class="k">def</span><span class="w"> </span><span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span 
class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="c1"># Apply layer normalization.</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">normalize</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="c1"># Transpose inputs from [num_batches, num_patches, hidden_units] to [num_batches, hidden_units, num_patches].</span> <span class="n">x_channels</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="c1"># Apply mlp1 on each channel independently.</span> <span class="n">mlp1_outputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mlp1</span><span class="p">(</span><span class="n">x_channels</span><span class="p">)</span> <span class="c1"># Transpose mlp1_outputs from [num_batches, hidden_dim, num_patches] to [num_batches, num_patches, hidden_units].</span> <span class="n">mlp1_outputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">mlp1_outputs</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="c1"># Add skip connection.</span> <span class="n">x</span> <span class="o">=</span> 
<span class="n">mlp1_outputs</span> <span class="o">+</span> <span class="n">inputs</span> <span class="c1"># Apply layer normalization.</span> <span class="n">x_patches</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">normalize</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Apply mlp2 on each patch independently.</span> <span class="n">mlp2_outputs</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">mlp2</span><span class="p">(</span><span class="n">x_patches</span><span class="p">)</span> <span class="c1"># Add skip connection.</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">mlp2_outputs</span> <span class="k">return</span> <span class="n">x</span> </code></pre></div> <h3 id="build-train-and-evaluate-the-mlpmixer-model">Build, train, and evaluate the MLP-Mixer model</h3> <p>Note that training the model with the current settings on a V100 GPU takes around 8 seconds per epoch.</p> <div class="codehilite"><pre><span></span><code><span class="n">mlpmixer_blocks</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span><span class="n">MLPMixerLayer</span><span class="p">(</span><span class="n">num_patches</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_blocks</span><span class="p">)]</span> <span class="p">)</span> <span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">0.005</span> <span class="n">mlpmixer_classifier</span> <span class="o">=</span> <span
class="n">build_classifier</span><span class="p">(</span><span class="n">mlpmixer_blocks</span><span class="p">)</span> <span class="n">history</span> <span class="o">=</span> <span class="n">run_experiment</span><span class="p">(</span><span class="n">mlpmixer_classifier</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Test accuracy: 9.76% Test top 5 accuracy: 30.8% </code></pre></div> </div> <p>The MLP-Mixer model tends to have far fewer parameters than convolutional and transformer-based models, which reduces training and serving computational cost.</p> <p>As mentioned in the <a href="https://arxiv.org/abs/2105.01601">MLP-Mixer</a> paper, when pre-trained on large datasets, or with modern regularization schemes, the MLP-Mixer attains scores competitive with state-of-the-art models. You can obtain better results by increasing the embedding dimensions, increasing the number of mixer blocks, and training the model for longer. You may also try to increase the size of the input images and use different patch sizes.</p> <hr /> <h2 id="the-fnet-model">The FNet model</h2> <p>The FNet uses a block similar to the Transformer block.
However, FNet replaces the self-attention layer in the Transformer block with a parameter-free 2D Fourier transformation layer:</p> <ol> <li>One 1D Fourier Transform is applied along the patches.</li> <li>One 1D Fourier Transform is applied along the channels.</li> </ol> <h3 id="implement-the-fnet-module">Implement the FNet module</h3> <div class="codehilite"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">FNetLayer</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">ffn</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="n">embedding_dim</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;gelu&quot;</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span 
class="p">(</span><span class="n">rate</span><span class="o">=</span><span class="n">dropout_rate</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="n">embedding_dim</span><span class="p">),</span> <span class="p">]</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">normalize1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span><span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">normalize2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span><span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">)</span> <span class="k">def</span><span class="w"> </span><span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="c1"># Apply Fourier transformations.</span> <span class="n">real_part</span> <span class="o">=</span> <span class="n">inputs</span> <span class="n">im_part</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">zeros_like</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">fft2</span><span class="p">((</span><span class="n">real_part</span><span class="p">,</span> <span class="n">im_part</span><span
class="p">))[</span><span class="mi">0</span><span class="p">]</span> <span class="c1"># Add skip connection.</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">inputs</span> <span class="c1"># Apply layer normalization.</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">normalize1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Apply feedforward network.</span> <span class="n">x_ffn</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">ffn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Add skip connection.</span> <span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">+</span> <span class="n">x_ffn</span> <span class="c1"># Apply layer normalization.</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">normalize2</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> </code></pre></div> <h3 id="build-train-and-evaluate-the-fnet-model">Build, train, and evaluate the FNet model</h3> <p>Note that training the model with the current settings on a V100 GPU takes around 8 seconds per epoch.</p> <div class="codehilite"><pre><span></span><code><span class="n">fnet_blocks</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span><span class="n">FNetLayer</span><span class="p">(</span><span class="n">embedding_dim</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_blocks</span><span
class="p">)]</span> <span class="p">)</span> <span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">0.001</span> <span class="n">fnet_classifier</span> <span class="o">=</span> <span class="n">build_classifier</span><span class="p">(</span><span class="n">fnet_blocks</span><span class="p">,</span> <span class="n">positional_encoding</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">history</span> <span class="o">=</span> <span class="n">run_experiment</span><span class="p">(</span><span class="n">fnet_classifier</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Test accuracy: 13.82% Test top 5 accuracy: 36.15% </code></pre></div> </div> <p>As shown in the <a href="https://arxiv.org/abs/2105.03824">FNet</a> paper, better results can be achieved by increasing the embedding dimensions, increasing the number of FNet blocks, and training the model for longer. You may also try to increase the size of the input images and use different patch sizes. The FNet scales very efficiently to long inputs, runs much faster than attention-based Transformer models, and produces competitive accuracy results.</p> <hr /> <h2 id="the-gmlp-model">The gMLP model</h2> <p>The gMLP is an MLP architecture that features a Spatial Gating Unit (SGU).
The SGU enables cross-patch interactions across the spatial (patch) dimension by:</p> <ol> <li>Transforming the input spatially by applying a linear projection across patches (independently per channel).</li> <li>Applying element-wise multiplication of the input and its spatial transformation.</li> </ol> <h3 id="implement-the-gmlp-module">Implement the gMLP module</h3> <div class="codehilite"><pre><span></span><code><span class="k">class</span><span class="w"> </span><span class="nc">gMLPLayer</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_patches</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="p">,</span> <span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">*</span><span class="n">args</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">channel_projection1</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="n">embedding_dim</span> <span class="o">*</span> <span class="mi">2</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span
class="s2">&quot;gelu&quot;</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">rate</span><span class="o">=</span><span class="n">dropout_rate</span><span class="p">),</span> <span class="p">]</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">channel_projection2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="n">embedding_dim</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">spatial_projection</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span> <span class="n">units</span><span class="o">=</span><span class="n">num_patches</span><span class="p">,</span> <span class="n">bias_initializer</span><span class="o">=</span><span class="s2">&quot;Ones&quot;</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">normalize1</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span><span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">normalize2</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">LayerNormalization</span><span class="p">(</span><span class="n">epsilon</span><span class="o">=</span><span class="mf">1e-6</span><span class="p">)</span> <span class="k">def</span><span class="w"> </span><span class="nf">spatial_gating_unit</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span 
class="n">x</span><span class="p">):</span> <span class="c1"># Split x along the channel dimension.</span> <span class="c1"># Tensors u and v will be of shape [batch_size, num_patches, embedding_dim].</span> <span class="n">u</span><span class="p">,</span> <span class="n">v</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">indices_or_sections</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> <span class="c1"># Apply layer normalization.</span> <span class="n">v</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">normalize2</span><span class="p">(</span><span class="n">v</span><span class="p">)</span> <span class="c1"># Apply spatial projection.</span> <span class="n">v_channels</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">v</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="n">v_projected</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">spatial_projection</span><span class="p">(</span><span class="n">v_channels</span><span class="p">)</span> <span class="n">v_projected</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span 
class="n">v_projected</span><span class="p">,</span> <span class="n">axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">1</span><span class="p">))</span> <span class="c1"># Apply element-wise multiplication.</span> <span class="k">return</span> <span class="n">u</span> <span class="o">*</span> <span class="n">v_projected</span> <span class="k">def</span><span class="w"> </span><span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="c1"># Apply layer normalization.</span> <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">normalize1</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="c1"># Apply the first channel projection. x_projected shape: [batch_size, num_patches, embedding_dim * 2].</span> <span class="n">x_projected</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">channel_projection1</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Apply the spatial gating unit. x_spatial shape: [batch_size, num_patches, embedding_dim].</span> <span class="n">x_spatial</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">spatial_gating_unit</span><span class="p">(</span><span class="n">x_projected</span><span class="p">)</span> <span class="c1"># Apply the second channel projection. 
x_projected shape: [batch_size, num_patches, embedding_dim].</span> <span class="n">x_projected</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">channel_projection2</span><span class="p">(</span><span class="n">x_spatial</span><span class="p">)</span> <span class="c1"># Add skip connection.</span> <span class="k">return</span> <span class="n">x</span> <span class="o">+</span> <span class="n">x_projected</span> </code></pre></div> <h3 id="build-train-and-evaluate-the-gmlp-model">Build, train, and evaluate the gMLP model</h3> <p>Note that training the model with the current settings on a V100 GPU takes around 9 seconds per epoch.</p> <div class="codehilite"><pre><span></span><code><span class="n">gmlp_blocks</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span><span class="n">gMLPLayer</span><span class="p">(</span><span class="n">num_patches</span><span class="p">,</span> <span class="n">embedding_dim</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="p">)</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_blocks</span><span class="p">)]</span> <span class="p">)</span> <span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">0.003</span> <span class="n">gmlp_classifier</span> <span class="o">=</span> <span class="n">build_classifier</span><span class="p">(</span><span class="n">gmlp_blocks</span><span class="p">)</span> <span class="n">history</span> <span class="o">=</span> <span class="n">run_experiment</span><span class="p">(</span><span class="n">gmlp_classifier</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Test accuracy: 17.05% Test top 5 accuracy: 42.57% 
</code></pre></div> </div> <p>As shown in the <a href="https://arxiv.org/abs/2105.08050">gMLP</a> paper, better results can be achieved by increasing the embedding dimensions, increasing the number of gMLP blocks, and training the model for longer. You may also try to increase the size of the input images and use different patch sizes. Note that the paper used advanced regularization strategies, such as MixUp and CutMix, as well as AutoAugment.</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#image-classification-with-modern-mlp-models'>Image classification with modern MLP models</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#prepare-the-data'>Prepare the data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#configure-the-hyperparameters'>Configure the hyperparameters</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#build-a-classification-model'>Build a classification model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#define-an-experiment'>Define an experiment</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#use-data-augmentation'>Use data augmentation</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#implement-patch-extraction-as-a-layer'>Implement patch extraction as a layer</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#implement-position-embedding-as-a-layer'>Implement position embedding as a layer</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#the-mlpmixer-model'>The MLP-Mixer model</a> </div> <div class='k-outline-depth-3'> <a href='#implement-the-mlpmixer-module'>Implement the MLP-Mixer module</a> </div> <div class='k-outline-depth-3'> <a href='#build-train-and-evaluate-the-mlpmixer-model'>Build, train, and evaluate the MLP-Mixer model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#the-fnet-model'>The FNet model</a> 
</div> <div class='k-outline-depth-3'> <a href='#implement-the-fnet-module'>Implement the FNet module</a> </div> <div class='k-outline-depth-3'> <a href='#build-train-and-evaluate-the-fnet-model'>Build, train, and evaluate the FNet model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#the-gmlp-model'>The gMLP model</a> </div> <div class='k-outline-depth-3'> <a href='#implement-the-gmlp-module'>Implement the gMLP module</a> </div> <div class='k-outline-depth-3'> <a href='#build-train-and-evaluate-the-gmlp-model'>Build, train, and evaluate the gMLP model</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>
