<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/forwardforward/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Using the Forward-Forward Algorithm for Image Classification"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Using the Forward-Forward Algorithm for Image Classification"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Using the Forward-Forward Algorithm for Image Classification</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); 
ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image 
classification with modern MLP models</a> <a class="nav-sublink2" href="/examples/vision/mobilevit/">A mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" 
href="/examples/vision/fully_convolutional_network/">Image Segmentation using Composable Fully-Convolutional Networks</a> <a class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2" 
href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved Robustness</a> <a class="nav-sublink2" href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a 
class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/bit/">Image Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2 active" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2" 
href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a class="nav-sublink2" href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; 
document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = 
'/search.html?query=' + query; return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/vision/'>Computer Vision</a> / Using the Forward-Forward Algorithm for Image Classification </div> <div class='k-content'> <h1 id="using-the-forwardforward-algorithm-for-image-classification">Using the Forward-Forward Algorithm for Image Classification</h1> <p><strong>Author:</strong> <a href="https://twitter.com/halcyonrayes">Suvaditya Mukherjee</a><br> <strong>Date created:</strong> 2023/01/08<br> <strong>Last modified:</strong> 2024/09/17<br> <strong>Description:</strong> Training a Dense-layer model using the Forward-Forward algorithm.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/forwardforward.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/forwardforward.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>The following example explores how to use the Forward-Forward algorithm to perform training instead of the traditionally used method of backpropagation, as proposed by Hinton in <a href="https://www.cs.toronto.edu/~hinton/FFA13.pdf">The Forward-Forward Algorithm: Some Preliminary Investigations</a> (2022).</p> <p>The concept draws inspiration from <a href="http://www.cs.toronto.edu/~fritz/absps/dbm.pdf">Boltzmann Machines</a>. Backpropagation involves calculating the difference between actual and predicted output via a cost function to adjust network weights. 
On the other hand, the FF algorithm proposes the analogy of neurons that become "excited" when they see a recognized combination of an image and its correct corresponding label.</p> <p>This method takes inspiration from the biological learning process that occurs in the cortex. A significant advantage of this method is that backpropagation through the network no longer needs to be performed, and weight updates are local to the layer itself.</p> <p>As this is still an experimental method, it does not yield state-of-the-art results. But with proper tuning, it is expected to come close to them. Through this example, we will examine a process that allows us to implement the Forward-Forward algorithm within the layers themselves, instead of the traditional method of relying on global loss functions and optimizers.</p> <p>The tutorial is structured as follows:</p> <ul> <li>Perform necessary imports</li> <li>Load the <a href="http://yann.lecun.com/exdb/mnist/">MNIST dataset</a></li> <li>Visualize random samples from the MNIST dataset</li> <li>Define an <code>FFDense</code> layer that overrides <code>call</code> and implements a custom <code>forward_forward</code> method which performs weight updates</li> <li>Define an <code>FFNetwork</code> model that overrides <code>train_step</code> and <code>predict</code>, and implements two custom functions for per-sample prediction and label overlaying</li> <li>Convert MNIST from <code>NumPy</code> arrays to <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><code>tf.data.Dataset</code></a></li> <li>Fit the network</li> <li>Visualize results</li> <li>Perform inference on test samples</li> </ul> <p>As this example requires the customization of certain core functions with <a href="/api/layers/base_layer#layer-class"><code>keras.layers.Layer</code></a> and <code>keras.models.Model</code>, refer to the following resources for a primer on how to do so:</p> <ul> <li><a 
href="https://keras.io/guides/customizing_what_happens_in_fit">Customizing what happens in <code>model.fit()</code></a></li> <li><a href="https://keras.io/guides/making_new_layers_and_models_via_subclassing">Making new Layers and Models via subclassing</a></li> </ul> <hr /> <h2 id="setup-imports">Setup imports</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"tensorflow"</span> <span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">ops</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span> <span class="kn">from</span> <span class="nn">sklearn.metrics</span> <span class="kn">import</span> <span class="n">accuracy_score</span> <span class="kn">import</span> <span class="nn">random</span> <span class="kn">from</span> <span class="nn">tensorflow.compiler.tf2xla.python</span> <span class="kn">import</span> <span class="n">xla</span> </code></pre></div> <hr /> <h2 id="load-the-dataset-and-visualize-the-data">Load the dataset and visualize the data</h2> <p>We use the <code>keras.datasets.mnist.load_data()</code> utility to directly pull the MNIST dataset in the form of <code>NumPy</code> arrays. 
We then arrange it into train and test splits.</p> <p>After loading the dataset, we select 4 random samples from the training set and visualize them using <code>matplotlib.pyplot</code>.</p> <div class="codehilite"><pre><span></span><code><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">mnist</span><span class="o">.</span><span class="n">load_data</span><span class="p">()</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"4 Random Training samples and labels"</span><span class="p">)</span> <span class="n">idx1</span><span class="p">,</span> <span class="n">idx2</span><span class="p">,</span> <span class="n">idx3</span><span class="p">,</span> <span class="n">idx4</span> <span class="o">=</span> <span class="n">random</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">x_train</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]),</span> <span class="mi">4</span><span class="p">)</span> <span class="n">img1</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_train</span><span class="p">[</span><span class="n">idx1</span><span class="p">],</span> <span class="n">y_train</span><span class="p">[</span><span class="n">idx1</span><span class="p">])</span> <span class="n">img2</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_train</span><span class="p">[</span><span class="n">idx2</span><span 
class="p">],</span> <span class="n">y_train</span><span class="p">[</span><span class="n">idx2</span><span class="p">])</span> <span class="n">img3</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_train</span><span class="p">[</span><span class="n">idx3</span><span class="p">],</span> <span class="n">y_train</span><span class="p">[</span><span class="n">idx3</span><span class="p">])</span> <span class="n">img4</span> <span class="o">=</span> <span class="p">(</span><span class="n">x_train</span><span class="p">[</span><span class="n">idx4</span><span class="p">],</span> <span class="n">y_train</span><span class="p">[</span><span class="n">idx4</span><span class="p">])</span> <span class="n">imgs</span> <span class="o">=</span> <span class="p">[</span><span class="n">img1</span><span class="p">,</span> <span class="n">img2</span><span class="p">,</span> <span class="n">img3</span><span class="p">,</span> <span class="n">img4</span><span class="p">]</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">item</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">imgs</span><span class="p">):</span> <span class="n">image</span><span class="p">,</span> <span class="n">label</span> <span class="o">=</span> <span class="n">item</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">item</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span 
class="mi">2</span><span class="p">,</span> <span class="n">idx</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">cmap</span><span class="o">=</span><span class="s2">"gray"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Label : </span><span class="si">{</span><span class="n">label</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>4 Random Training samples and labels </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/forwardforward/forwardforward_5_1.png" /></p> <hr /> <h2 id="define-ffdense-custom-layer">Define <code>FFDense</code> custom layer</h2> <p>In this custom layer, we have a <a href="/api/layers/core_layers/dense#dense-class"><code>keras.layers.Dense</code></a> object which acts as the underlying <code>Dense</code> layer. Since weight updates happen within the layer itself, we accept a <a href="/api/optimizers#optimizer-class"><code>keras.optimizers.Optimizer</code></a> object from the user. Here, we use <code>Adam</code> as our optimizer with a relatively high learning rate of <code>0.03</code>.</p> <p>Following the algorithm's specifics, we must set a <code>threshold</code> parameter that will be used to make the positive-negative decision for each prediction. This is set to a default of 1.5. 
As the epochs are localized to the layer itself, we also set a <code>num_epochs</code> parameter (defaults to 50).</p> <p>We override the <code>call</code> method to normalize the complete input space before running it through the base <code>Dense</code> layer, as would happen in a normal <code>Dense</code> layer call.</p> <p>We implement the Forward-Forward algorithm, which accepts two kinds of input tensors, representing the positive and negative samples respectively. We write a custom training loop here using <code>tf.GradientTape()</code>, within which we calculate a per-sample loss as the distance of the prediction from the threshold, and take its mean to obtain a <code>mean_loss</code> metric.</p> <p>With the help of <code>tf.GradientTape()</code>, we calculate the gradient updates for the trainable base <code>Dense</code> layer and apply them using the layer's local optimizer.</p> <p>Finally, we return the <code>call</code> results of the positive and negative samples, along with the last <code>mean_loss</code> metric and all the loss values over the full epoch run.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">FFDense</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""</span> <span class="sd"> A custom ForwardForward-enabled Dense layer. 
It has an implementation of the</span> <span class="sd"> Forward-Forward network internally for use.</span> <span class="sd"> This layer must be used in conjunction with the `FFNetwork` model.</span> <span class="sd"> """</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">units</span><span class="p">,</span> <span class="n">init_optimizer</span><span class="p">,</span> <span class="n">loss_metric</span><span class="p">,</span> <span class="n">num_epochs</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">kernel_initializer</span><span class="o">=</span><span class="s2">"glorot_uniform"</span><span class="p">,</span> <span class="n">bias_initializer</span><span class="o">=</span><span class="s2">"zeros"</span><span class="p">,</span> <span class="n">kernel_regularizer</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">bias_regularizer</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span> <span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span> <span class="n">units</span><span class="o">=</span><span class="n">units</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span 
class="n">use_bias</span><span class="p">,</span> <span class="n">kernel_initializer</span><span class="o">=</span><span class="n">kernel_initializer</span><span class="p">,</span> <span class="n">bias_initializer</span><span class="o">=</span><span class="n">bias_initializer</span><span class="p">,</span> <span class="n">kernel_regularizer</span><span class="o">=</span><span class="n">kernel_regularizer</span><span class="p">,</span> <span class="n">bias_regularizer</span><span class="o">=</span><span class="n">bias_regularizer</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">relu</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">ReLU</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">init_optimizer</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_metric</span> <span class="o">=</span> <span class="n">loss_metric</span> <span class="bp">self</span><span class="o">.</span><span class="n">threshold</span> <span class="o">=</span> <span class="mf">1.5</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_epochs</span> <span class="o">=</span> <span class="n">num_epochs</span> <span class="c1"># We perform a normalization step before we run the input through the Dense</span> <span class="c1"># layer.</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">norm</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span 
class="nb">ord</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">x_norm</span> <span class="o">=</span> <span class="n">x_norm</span> <span class="o">+</span> <span class="mf">1e-4</span> <span class="n">x_dir</span> <span class="o">=</span> <span class="n">x</span> <span class="o">/</span> <span class="n">x_norm</span> <span class="n">res</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense</span><span class="p">(</span><span class="n">x_dir</span><span class="p">)</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">res</span><span class="p">)</span> <span class="c1"># The Forward-Forward algorithm is below. We first perform the Dense-layer</span> <span class="c1"># operation and then compute a mean-squared value for all positive and negative</span> <span class="c1"># samples respectively.</span> <span class="c1"># The custom loss function finds the distance between the mean-squared</span> <span class="c1"># result and the threshold value we set (a hyperparameter) that will define</span> <span class="c1"># whether the prediction is positive or negative in nature. Once the loss is</span> <span class="c1"># calculated, we take the mean across the entire batch and perform a</span> <span class="c1"># gradient calculation and optimization step. 
This does not technically</span> <span class="c1"># qualify as backpropagation since there is no gradient being</span> <span class="c1"># sent to any previous layer; the update is completely local in nature.</span> <span class="k">def</span> <span class="nf">forward_forward</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x_pos</span><span class="p">,</span> <span class="n">x_neg</span><span class="p">):</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">num_epochs</span><span class="p">):</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="n">g_pos</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">power</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">call</span><span class="p">(</span><span class="n">x_pos</span><span class="p">),</span> <span class="mi">2</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span> <span class="n">g_neg</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">power</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">call</span><span class="p">(</span><span class="n">x_neg</span><span class="p">),</span> <span class="mi">2</span><span class="p">),</span> <span class="mi">1</span><span class="p">)</span> <span class="n">loss</span> <span class="o">=</span> 
<span class="n">ops</span><span class="o">.</span><span class="n">log</span><span class="p">(</span> <span class="mi">1</span> <span class="o">+</span> <span class="n">ops</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span> <span class="n">ops</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span> <span class="p">[</span><span class="o">-</span><span class="n">g_pos</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">threshold</span><span class="p">,</span> <span class="n">g_neg</span> <span class="o">-</span> <span class="bp">self</span><span class="o">.</span><span class="n">threshold</span><span class="p">],</span> <span class="mi">0</span> <span class="p">)</span> <span class="p">)</span> <span class="p">)</span> <span class="n">mean_loss</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">loss</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_metric</span><span class="o">.</span><span class="n">update_state</span><span class="p">([</span><span class="n">mean_loss</span><span class="p">])</span> <span class="n">gradients</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">mean_loss</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span 
class="n">apply_gradients</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dense</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">))</span> <span class="k">return</span> <span class="p">(</span> <span class="n">ops</span><span class="o">.</span><span class="n">stop_gradient</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">call</span><span class="p">(</span><span class="n">x_pos</span><span class="p">)),</span> <span class="n">ops</span><span class="o">.</span><span class="n">stop_gradient</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">call</span><span class="p">(</span><span class="n">x_neg</span><span class="p">)),</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_metric</span><span class="o">.</span><span class="n">result</span><span class="p">(),</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="define-the-ffnetwork-custom-model">Define the <code>FFNetwork</code> Custom Model</h2> <p>With our custom layer defined, we also need to override the <code>train_step</code> method and define a custom <code>keras.models.Model</code> that works with our <code>FFDense</code> layer.</p> <p>For this algorithm, we must 'embed' the labels onto the original image. To do so, we exploit the structure of MNIST images where the top-left 10 pixels are always zeros. We use that as a label space in order to visually one-hot-encode the labels within the image itself. This action is performed by the <code>overlay_y_on_x</code> function.</p> <p>We break the prediction function down into a per-sample prediction function, which is then called over the entire test set by the overridden <code>predict()</code> function. 
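</p> <p>As a rough standalone sketch of the label-embedding idea described above (plain NumPy, with a hypothetical <code>overlay_label</code> helper rather than the example's <code>overlay_y_on_x</code>), writing a pixel-based one-hot encoding of the label into the first 10 pixels of a flattened image looks like this:</p>

```python
# Illustrative NumPy sketch (hypothetical helper, not the example's
# overlay_y_on_x): embed a label into the first 10 pixels of a
# flattened 28x28 image as a pixel-based one-hot encoding.
import numpy as np


def overlay_label(x_flat, label, num_classes=10):
    x = x_flat.copy()
    x[:num_classes] = 0.0      # clear the label space
    x[label] = x_flat.max()    # "hot" pixel at the label's index
    return x


image = np.random.rand(784)
positive = overlay_label(image, 3)  # correct label -> positive sample
```

<p>A positive sample carries the image's true label; a negative sample carries a randomly chosen wrong label.</p> <p>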
The prediction is performed here by measuring the <code>excitation</code> of the neurons in each layer for each image. This is then summed over all layers to calculate a network-wide 'goodness score'. The label with the highest 'goodness score' is then chosen as the sample prediction.</p> <p>The <code>train_step</code> function is overridden to act as the main controlling loop, running training on each layer for the configured number of epochs per layer.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">FFNetwork</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""</span> <span class="sd"> A [`keras.Model`](/api/models/model#model-class) that supports a `FFDense` network creation. This model</span> <span class="sd"> can work for any kind of classification task. It has an internal</span> <span class="sd"> implementation with some details specific to the MNIST dataset which can be</span> <span class="sd"> changed as per the use-case.</span> <span class="sd"> """</span> <span class="c1"># Since each layer runs gradient-calculation and optimization locally, each</span> <span class="c1"># layer has its own optimizer that we pass. 
As a standard choice, we pass</span> <span class="c1"># the `Adam` optimizer with a default learning rate of 0.03 as that was</span> <span class="c1"># found to be the best rate after experimentation.</span> <span class="c1"># Loss is tracked using `loss_var` and `loss_count` variables.</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span> <span class="bp">self</span><span class="p">,</span> <span class="n">dims</span><span class="p">,</span> <span class="n">init_layer_optimizer</span><span class="o">=</span><span class="k">lambda</span><span class="p">:</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.03</span><span class="p">),</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">,</span> <span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">init_layer_optimizer</span> <span class="o">=</span> <span class="n">init_layer_optimizer</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_var</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_count</span> <span class="o">=</span> <span class="n">keras</span><span 
class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="mf">0.0</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_list</span> <span class="o">=</span> <span class="p">[</span><span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">dims</span><span class="p">[</span><span class="mi">0</span><span class="p">],))]</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics_built</span> <span class="o">=</span> <span class="kc">False</span> <span class="k">for</span> <span class="n">d</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">dims</span><span class="p">)</span> <span class="o">-</span> <span class="mi">1</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_list</span> <span class="o">+=</span> <span class="p">[</span> <span class="n">FFDense</span><span class="p">(</span> <span class="n">dims</span><span class="p">[</span><span class="n">d</span> <span class="o">+</span> <span class="mi">1</span><span class="p">],</span> <span class="n">init_optimizer</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">init_layer_optimizer</span><span class="p">,</span> <span class="n">loss_metric</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Mean</span><span class="p">(),</span> <span class="p">)</span> <span 
class="p">]</span> <span class="c1"># This function makes a dynamic change to the image wherein the labels are</span> <span class="c1"># put on top of the original image (for this example, as MNIST has 10</span> <span class="c1"># unique labels, we take the top-left corner's first 10 pixels). This</span> <span class="c1"># function returns the original data tensor with the first 10 pixels being</span> <span class="c1"># a pixel-based one-hot representation of the labels.</span> <span class="nd">@tf</span><span class="o">.</span><span class="n">function</span><span class="p">(</span><span class="n">reduce_retracing</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="k">def</span> <span class="nf">overlay_y_on_x</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span> <span class="n">X_sample</span><span class="p">,</span> <span class="n">y_sample</span> <span class="o">=</span> <span class="n">data</span> <span class="n">max_sample</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">amax</span><span class="p">(</span><span class="n">X_sample</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">max_sample</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">max_sample</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float64"</span><span class="p">)</span> <span class="n">X_zeros</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="mi">10</span><span 
class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float64"</span><span class="p">)</span> <span class="n">X_update</span> <span class="o">=</span> <span class="n">xla</span><span class="o">.</span><span class="n">dynamic_update_slice</span><span class="p">(</span><span class="n">X_zeros</span><span class="p">,</span> <span class="n">max_sample</span><span class="p">,</span> <span class="p">[</span><span class="n">y_sample</span><span class="p">])</span> <span class="n">X_sample</span> <span class="o">=</span> <span class="n">xla</span><span class="o">.</span><span class="n">dynamic_update_slice</span><span class="p">(</span><span class="n">X_sample</span><span class="p">,</span> <span class="n">X_update</span><span class="p">,</span> <span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="k">return</span> <span class="n">X_sample</span><span class="p">,</span> <span class="n">y_sample</span> <span class="c1"># A custom `predict_one_sample` performs predictions by passing the images</span> <span class="c1"># through the network, measuring the results produced by each layer (i.e.</span> <span class="c1"># how high/low the output values are with respect to the set threshold for</span> <span class="c1"># each label) and then simply picking the label with the highest values.</span> <span class="c1"># In this way, the images are tested for their 'goodness' against all</span> <span class="c1"># labels.</span> <span class="nd">@tf</span><span class="o">.</span><span class="n">function</span><span class="p">(</span><span class="n">reduce_retracing</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="k">def</span> <span class="nf">predict_one_sample</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">):</span> <span class="n">goodness_per_label</span> <span class="o">=</span> <span 
class="p">[]</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="p">[</span><span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">)[</span><span class="mi">1</span><span class="p">]])</span> <span class="k">for</span> <span class="n">label</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">10</span><span class="p">):</span> <span class="n">h</span><span class="p">,</span> <span class="n">label</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">overlay_y_on_x</span><span class="p">(</span><span class="n">data</span><span class="o">=</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">label</span><span class="p">))</span> <span class="n">h</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">h</span><span class="p">)[</span><span class="mi">0</span><span class="p">]])</span> <span class="n">goodness</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">layer_idx</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span 
class="mi">1</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layer_list</span><span class="p">)):</span> <span class="n">layer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">layer_list</span><span class="p">[</span><span class="n">layer_idx</span><span class="p">]</span> <span class="n">h</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">h</span><span class="p">)</span> <span class="n">goodness</span> <span class="o">+=</span> <span class="p">[</span><span class="n">ops</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">power</span><span class="p">(</span><span class="n">h</span><span class="p">,</span> <span class="mi">2</span><span class="p">),</span> <span class="mi">1</span><span class="p">)]</span> <span class="n">goodness_per_label</span> <span class="o">+=</span> <span class="p">[</span><span class="n">ops</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">goodness</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">True</span><span class="p">),</span> <span class="mi">1</span><span class="p">)]</span> <span class="n">goodness_per_label</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">concat</span><span class="p">(</span><span class="n">goodness_per_label</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="k">return</span> <span class="n">ops</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">ops</span><span 
class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="n">goodness_per_label</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float64"</span><span class="p">)</span> <span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">data</span> <span class="n">preds</span> <span class="o">=</span> <span class="nb">list</span><span class="p">()</span> <span class="n">preds</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">vectorized_map</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">predict_one_sample</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="nb">int</span><span class="p">)</span> <span class="c1"># This custom `train_step` function overrides the internal `train_step`</span> <span class="c1"># implementation. We take all the input image tensors, flatten them and</span> <span class="c1"># subsequently produce positive and negative samples on the images.</span> <span class="c1"># A positive sample is an image that has the right label encoded on it with</span> <span class="c1"># the `overlay_y_on_x` function. A negative sample is an image that has an</span> <span class="c1"># erroneous label present on it.</span> <span class="c1"># With the samples ready, we pass them through each `FFLayer` and perform</span> <span class="c1"># the Forward-Forward computation on it. 
The returned loss is the final</span> <span class="c1"># loss value over all the layers.</span> <span class="nd">@tf</span><span class="o">.</span><span class="n">function</span><span class="p">(</span><span class="n">jit_compile</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">data</span><span class="p">):</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">data</span> <span class="k">if</span> <span class="ow">not</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics_built</span><span class="p">:</span> <span class="c1"># build metrics to ensure they can be queried without erroring out.</span> <span class="c1"># We can't update the metrics' state, as we would usually do, since</span> <span class="c1"># we do not perform predictions within the train step</span> <span class="k">for</span> <span class="n">metric</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics</span><span class="p">:</span> <span class="k">if</span> <span class="nb">hasattr</span><span class="p">(</span><span class="n">metric</span><span class="p">,</span> <span class="s2">"build"</span><span class="p">):</span> <span class="n">metric</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">metrics_built</span> <span class="o">=</span> <span class="kc">True</span> <span class="c1"># Flatten op</span> <span class="n">x</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span 
class="n">x</span><span class="p">,</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">)[</span><span class="mi">2</span><span class="p">]])</span> <span class="n">x_pos</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">vectorized_map</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">overlay_y_on_x</span><span class="p">,</span> <span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">))</span> <span class="n">random_y</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">y</span><span class="p">)</span> <span class="n">x_neg</span><span class="p">,</span> <span class="n">y</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">map_fn</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">overlay_y_on_x</span><span class="p">,</span> <span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">random_y</span><span class="p">))</span> <span class="n">h_pos</span><span class="p">,</span> <span class="n">h_neg</span> <span class="o">=</span> <span class="n">x_pos</span><span class="p">,</span> <span class="n">x_neg</span> <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> 
<span class="n">layer</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">layers</span><span class="p">):</span> <span class="k">if</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">FFDense</span><span class="p">):</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Training layer </span><span class="si">{</span><span class="n">idx</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s2"> now : "</span><span class="p">)</span> <span class="n">h_pos</span><span class="p">,</span> <span class="n">h_neg</span><span class="p">,</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">layer</span><span class="o">.</span><span class="n">forward_forward</span><span class="p">(</span><span class="n">h_pos</span><span class="p">,</span> <span class="n">h_neg</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_var</span><span class="o">.</span><span class="n">assign_add</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_count</span><span class="o">.</span><span class="n">assign_add</span><span class="p">(</span><span class="mf">1.0</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Passing layer </span><span class="si">{</span><span class="n">idx</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s2"> now : "</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span 
class="n">x</span><span class="p">)</span> <span class="n">mean_res</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">divide</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">loss_var</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_count</span><span class="p">)</span> <span class="k">return</span> <span class="p">{</span><span class="s2">"FinalLoss"</span><span class="p">:</span> <span class="n">mean_res</span><span class="p">}</span> </code></pre></div> <hr /> <h2 id="tfdatadataset">Convert MNIST <code>NumPy</code> arrays to <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><code>tf.data.Dataset</code></a></h2> <p>We now perform some preliminary processing on the <code>NumPy</code> arrays and then convert them into the <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><code>tf.data.Dataset</code></a> format which allows for optimized loading.</p> <div class="codehilite"><pre><span></span><code><span class="n">x_train</span> <span class="o">=</span> <span class="n">x_train</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">float</span><span class="p">)</span> <span class="o">/</span> <span class="mi">255</span> <span class="n">x_test</span> <span class="o">=</span> <span class="n">x_test</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">float</span><span class="p">)</span> <span class="o">/</span> <span class="mi">255</span> <span class="n">y_train</span> <span class="o">=</span> <span class="n">y_train</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">)</span> <span class="n">y_test</span> <span class="o">=</span> <span class="n">y_test</span><span class="o">.</span><span class="n">astype</span><span 
class="p">(</span><span class="nb">int</span><span class="p">)</span> <span class="n">train_dataset</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">))</span> <span class="n">test_dataset</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">))</span> <span class="n">train_dataset</span> <span class="o">=</span> <span class="n">train_dataset</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="mi">60000</span><span class="p">)</span> <span class="n">test_dataset</span> <span class="o">=</span> <span class="n">test_dataset</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="mi">10000</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="fit-the-network-and-visualize-results">Fit the network and visualize results</h2> <p>Having performed all the previous set-up, we now run <code>model.fit()</code> for 250 model epochs; since each <code>FFDense</code> layer trains itself for 50 local epochs per call, this performs 50*250 epochs on each layer. 
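</p> <p>Before fitting, it can help to see the per-layer objective in isolation. Below is a standalone NumPy sketch of the Forward-Forward loss used by each <code>FFDense</code> layer (illustrative only; the example itself computes it with <code>keras.ops</code> inside a <code>tf.GradientTape</code>):</p>

```python
# Standalone sketch of the per-layer Forward-Forward objective:
# push the "goodness" (mean-squared activation) of positive samples
# above the threshold and that of negative samples below it.
import numpy as np


def ff_loss(h_pos, h_neg, threshold=1.5):
    g_pos = np.mean(h_pos**2, axis=1)  # goodness of positive samples
    g_neg = np.mean(h_neg**2, axis=1)  # goodness of negative samples
    logits = np.concatenate([-g_pos + threshold, g_neg - threshold])
    # Softplus of the logits, averaged over the combined batch.
    return np.mean(np.log1p(np.exp(logits)))
```

<p>With strongly activated positive samples and weakly activated negative samples the loss is small; swapping the two makes it large.</p> <p>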
We get to see the plotted loss curve as each layer is trained.</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="n">FFNetwork</span><span class="p">(</span><span class="n">dims</span><span class="o">=</span><span class="p">[</span><span class="mi">784</span><span class="p">,</span> <span class="mi">500</span><span class="p">,</span> <span class="mi">500</span><span class="p">])</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">0.03</span><span class="p">),</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"mse"</span><span class="p">,</span> <span class="n">jit_compile</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[],</span> <span class="p">)</span> <span class="n">epochs</span> <span class="o">=</span> <span class="mi">250</span> <span class="n">history</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">epochs</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/250 Training layer 1 now : Training layer 2 now : Training layer 1 now : Training layer 2 now : 1/1 ━━━━━━━━━━━━━━━━━━━━ 90s 90s/step - FinalLoss: 0.7247 Epoch 2/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.7089 Epoch 3/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 
41s 41s/step - FinalLoss: 0.6978 Epoch 4/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.6827 Epoch 5/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.6644 Epoch 6/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.6462 Epoch 7/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.6290 Epoch 8/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.6131 Epoch 9/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.5986 Epoch 10/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.5853 Epoch 11/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.5731 Epoch 12/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.5621 Epoch 13/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.5519 Epoch 14/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.5425 Epoch 15/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.5338 Epoch 16/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.5259 Epoch 17/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.5186 Epoch 18/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.5117 Epoch 19/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.5052 Epoch 20/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.4992 Epoch 21/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.4935 Epoch 22/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.4883 Epoch 23/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.4833 Epoch 24/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.4786 Epoch 25/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.4741 Epoch 26/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.4698 Epoch 27/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.4658 Epoch 28/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.4620 Epoch 29/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.4584 Epoch 30/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.4550 Epoch 31/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - 
FinalLoss: 0.4517 Epoch 32/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.4486 Epoch 33/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.4456 Epoch 34/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.4429 Epoch 35/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.4401 Epoch 36/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.4375 Epoch 37/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.4350 Epoch 38/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.4325 Epoch 39/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.4302 Epoch 40/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.4279 Epoch 41/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.4258 Epoch 42/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.4236 Epoch 43/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.4216 Epoch 44/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.4197 Epoch 45/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.4177 Epoch 46/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.4159 Epoch 47/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.4141 Epoch 48/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.4124 Epoch 49/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.4107 Epoch 50/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.4090 Epoch 51/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.4074 Epoch 52/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.4059 Epoch 53/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.4044 Epoch 54/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.4030 Epoch 55/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.4016 Epoch 56/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.4002 Epoch 57/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3988 Epoch 58/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3975 Epoch 59/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 
0.3962 Epoch 60/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3950 Epoch 61/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3938 Epoch 62/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3926 Epoch 63/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3914 Epoch 64/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3903 Epoch 65/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3891 Epoch 66/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3880 Epoch 67/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3869 Epoch 68/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3859 Epoch 69/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3849 Epoch 70/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3839 Epoch 71/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3829 Epoch 72/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3819 Epoch 73/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3810 Epoch 74/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3801 Epoch 75/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3792 Epoch 76/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3783 Epoch 77/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3774 Epoch 78/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3765 Epoch 79/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3757 Epoch 80/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3748 Epoch 81/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3740 Epoch 82/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3732 Epoch 83/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3723 Epoch 84/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3715 Epoch 85/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3708 Epoch 86/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3700 Epoch 87/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3692 
Epoch 88/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3685 Epoch 89/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3677 Epoch 90/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3670 Epoch 91/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3663 Epoch 92/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3656 Epoch 93/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3649 Epoch 94/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3642 Epoch 95/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3635 Epoch 96/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3629 Epoch 97/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3622 Epoch 98/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3616 Epoch 99/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3610 Epoch 100/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3603 Epoch 101/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3597 Epoch 102/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3591 Epoch 103/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3585 Epoch 104/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3579 Epoch 105/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3573 Epoch 106/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3567 Epoch 107/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3562 Epoch 108/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3556 Epoch 109/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3550 Epoch 110/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3545 Epoch 111/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3539 Epoch 112/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3534 Epoch 113/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3529 Epoch 114/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3524 Epoch 115/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 
0.3519 Epoch 116/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3513 Epoch 117/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3508 Epoch 118/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3503 Epoch 119/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3498 Epoch 120/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3493 Epoch 121/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3488 Epoch 122/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3484 Epoch 123/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3479 Epoch 124/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3474 Epoch 125/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3470 Epoch 126/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3465 Epoch 127/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3461 Epoch 128/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3456 Epoch 129/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3452 Epoch 130/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3447 Epoch 131/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3443 Epoch 132/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3439 Epoch 133/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3435 Epoch 134/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3430 Epoch 135/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3426 Epoch 136/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3422 Epoch 137/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3418 Epoch 138/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3414 Epoch 139/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3411 Epoch 140/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3407 Epoch 141/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3403 Epoch 142/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3399 Epoch 143/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 
40s/step - FinalLoss: 0.3395 Epoch 144/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3391 Epoch 145/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3387 Epoch 146/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3384 Epoch 147/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3380 Epoch 148/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3376 Epoch 149/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3373 Epoch 150/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3369 Epoch 151/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3366 Epoch 152/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3362 Epoch 153/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3359 Epoch 154/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3355 Epoch 155/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3352 Epoch 156/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3349 Epoch 157/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3346 Epoch 158/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3342 Epoch 159/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3339 Epoch 160/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3336 Epoch 161/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3333 Epoch 162/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3330 Epoch 163/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3327 Epoch 164/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3324 Epoch 165/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3321 Epoch 166/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3318 Epoch 167/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3315 Epoch 168/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3312 Epoch 169/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3309 Epoch 170/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3306 Epoch 171/250 1/1 
━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3303 Epoch 172/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3301 Epoch 173/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3298 Epoch 174/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3295 Epoch 175/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3292 Epoch 176/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3289 Epoch 177/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3287 Epoch 178/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3284 Epoch 179/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3281 Epoch 180/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3279 Epoch 181/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3276 Epoch 182/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3273 Epoch 183/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3271 Epoch 184/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3268 Epoch 185/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3266 Epoch 186/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3263 Epoch 187/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3261 Epoch 188/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3259 Epoch 189/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3256 Epoch 190/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3254 Epoch 191/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3251 Epoch 192/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3249 Epoch 193/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3247 Epoch 194/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3244 Epoch 195/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3242 Epoch 196/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3240 Epoch 197/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3238 Epoch 198/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3235 
Epoch 199/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3233 Epoch 200/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3231 Epoch 201/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3228 Epoch 202/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3226 Epoch 203/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3224 Epoch 204/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3222 Epoch 205/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3220 Epoch 206/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3217 Epoch 207/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3215 Epoch 208/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3213 Epoch 209/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3211 Epoch 210/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3209 Epoch 211/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3207 Epoch 212/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3205 Epoch 213/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3203 Epoch 214/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3201 Epoch 215/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3199 Epoch 216/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3197 Epoch 217/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3195 Epoch 218/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3193 Epoch 219/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3191 Epoch 220/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3190 Epoch 221/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3188 Epoch 222/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3186 Epoch 223/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3184 Epoch 224/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3182 Epoch 225/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3180 Epoch 226/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - 
FinalLoss: 0.3179 Epoch 227/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3177 Epoch 228/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3175 Epoch 229/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3173 Epoch 230/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3171 Epoch 231/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3170 Epoch 232/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3168 Epoch 233/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3166 Epoch 234/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3164 Epoch 235/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3163 Epoch 236/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3161 Epoch 237/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3159 Epoch 238/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3158 Epoch 239/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3156 Epoch 240/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 41s 41s/step - FinalLoss: 0.3154 Epoch 241/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3152 Epoch 242/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3151 Epoch 243/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3149 Epoch 244/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3148 Epoch 245/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3146 Epoch 246/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3145 Epoch 247/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3143 Epoch 248/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3141 Epoch 249/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3140 Epoch 250/250 1/1 ━━━━━━━━━━━━━━━━━━━━ 40s 40s/step - FinalLoss: 0.3138 </code></pre></div> </div> <hr /> <h2 id="perform-inference-and-testing">Perform inference and testing</h2> <p>Having trained the model to a large extent, we now see how it performs on the test set. 
We compute the accuracy score to evaluate the results.</p> <div class="codehilite"><pre><span></span><code><span class="n">preds</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">ops</span><span class="o">.</span><span class="n">convert_to_tensor</span><span class="p">(</span><span class="n">x_test</span><span class="p">))</span> <span class="n">preds</span> <span class="o">=</span> <span class="n">preds</span><span class="o">.</span><span class="n">reshape</span><span class="p">((</span><span class="n">preds</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">preds</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span> <span class="n">results</span> <span class="o">=</span> <span class="n">accuracy_score</span><span class="p">(</span><span class="n">preds</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Test Accuracy score : </span><span class="si">{</span><span class="n">results</span><span class="o">*</span><span class="mi">100</span><span class="si">}</span><span class="s2">%"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">"FinalLoss"</span><span class="p">])),</span> <span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">"FinalLoss"</span><span class="p">])</span> 
<span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">"Loss over training"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Test Accuracy score : 97.56% </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/forwardforward/forwardforward_15_1.png" /></p> <hr /> <h2 id="conclusion">Conclusion</h2> <p>This example has hereby demonstrated how the Forward-Forward algorithm works using the TensorFlow and Keras packages. While the investigation results presented by Prof. Hinton in their paper are currently still limited to smaller models and datasets like MNIST and Fashion-MNIST, subsequent results on larger models like LLMs are expected in future papers.</p> <p>Through the paper, Prof. Hinton has reported results of 1.36% test accuracy error with a 2000-units, 4 hidden-layer, fully-connected network run over 60 epochs (while mentioning that backpropagation takes only 20 epochs to achieve similar performance). Another run of doubling the learning rate and training for 40 epochs yields a slightly worse error rate of 1.46%</p> <p>The current example does not yield state-of-the-art results. 
But with proper tuning of the learning rate and the model architecture (number of units in <code>Dense</code> layers, kernel activations, initializations, regularization, etc.), the results can be improved to match the claims of the paper.</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#using-the-forwardforward-algorithm-for-image-classification'>Using the Forward-Forward Algorithm for Image Classification</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup-imports'>Setup imports</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#load-the-dataset-and-visualize-the-data'>Load the dataset and visualize the data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#define-ffdense-custom-layer'>Define <code>FFDense</code> custom layer</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#define-the-ffnetwork-custom-model'>Define the <code>FFNetwork</code> Custom Model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#convert-mnist-numpy-arrays-to-tfdatadataset'>Convert MNIST <code>NumPy</code> arrays to <code>tf.data.Dataset</code></a> </div> <div class='k-outline-depth-2'> ◆ <a href='#fit-the-network-and-visualize-results'>Fit the network and visualize results</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#perform-inference-and-testing'>Perform inference and testing</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#conclusion'>Conclusion</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>