CINXE.COM
Next-Frame Video Prediction with Convolutional LSTMs
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/conv_lstm/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Next-Frame Video Prediction with Convolutional LSTMs"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Next-Frame Video Prediction with Convolutional LSTMs"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Next-Frame Video Prediction with Convolutional LSTMs</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink active" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink2" href="/examples/vision/image_classification_from_scratch/">Image classification from scratch</a> <a class="nav-sublink2" href="/examples/vision/mnist_convnet/">Simple MNIST convnet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_efficientnet_fine_tuning/">Image classification via fine-tuning with EfficientNet</a> <a class="nav-sublink2" href="/examples/vision/image_classification_with_vision_transformer/">Image classification with Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/attention_mil_classification/">Classification using Attention-based Deep Multiple Instance Learning</a> <a class="nav-sublink2" href="/examples/vision/mlp_image_classification/">Image classification with modern MLP models</a> <a class="nav-sublink2" href="/examples/vision/mobilevit/">A mobile-friendly Transformer-based model for image classification</a> <a class="nav-sublink2" href="/examples/vision/xray_classification_with_tpus/">Pneumonia Classification on TPU</a> <a class="nav-sublink2" href="/examples/vision/cct/">Compact Convolutional Transformers</a> <a class="nav-sublink2" href="/examples/vision/convmixer/">Image classification with ConvMixer</a> <a class="nav-sublink2" href="/examples/vision/eanet/">Image classification with EANet (External Attention Transformer)</a> <a class="nav-sublink2" href="/examples/vision/involution/">Involutional neural networks</a> <a class="nav-sublink2" href="/examples/vision/perceiver_image_classification/">Image classification with Perceiver</a> <a class="nav-sublink2" href="/examples/vision/reptile/">Few-Shot learning with Reptile</a> <a class="nav-sublink2" href="/examples/vision/semisupervised_simclr/">Semi-supervised image classification using contrastive pretraining with SimCLR</a> <a class="nav-sublink2" href="/examples/vision/swin_transformers/">Image classification with Swin Transformers</a> <a class="nav-sublink2" href="/examples/vision/vit_small_ds/">Train a Vision Transformer on small datasets</a> <a class="nav-sublink2" href="/examples/vision/shiftvit/">A Vision Transformer without Attention</a> <a class="nav-sublink2" href="/examples/vision/image_classification_using_global_context_vision_transformer/">Image Classification using Global Context Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/oxford_pets_image_segmentation/">Image segmentation with a U-Net-like architecture</a> <a class="nav-sublink2" href="/examples/vision/deeplabv3_plus/">Multiclass semantic segmentation using DeepLabV3+</a> <a class="nav-sublink2" href="/examples/vision/basnet_segmentation/">Highly accurate boundaries segmentation using BASNet</a> <a class="nav-sublink2" href="/examples/vision/fully_convolutional_network/">Image Segmentation using Composable Fully-Convolutional Networks</a> <a class="nav-sublink2" href="/examples/vision/retinanet/">Object Detection with RetinaNet</a> <a class="nav-sublink2" href="/examples/vision/keypoint_detection/">Keypoint Detection with Transfer Learning</a> <a class="nav-sublink2" href="/examples/vision/object_detection_using_vision_transformer/">Object detection with Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/3D_image_classification/">3D image classification from CT scans</a> <a class="nav-sublink2" href="/examples/vision/depth_estimation/">Monocular depth estimation</a> <a class="nav-sublink2" href="/examples/vision/nerf/">3D volumetric rendering with NeRF</a> <a class="nav-sublink2" href="/examples/vision/pointnet_segmentation/">Point cloud segmentation with PointNet</a> <a class="nav-sublink2" href="/examples/vision/pointnet/">Point cloud classification</a> <a class="nav-sublink2" href="/examples/vision/captcha_ocr/">OCR model for reading Captchas</a> <a class="nav-sublink2" href="/examples/vision/handwriting_recognition/">Handwriting recognition</a> <a class="nav-sublink2" href="/examples/vision/autoencoder/">Convolutional autoencoder for image denoising</a> <a class="nav-sublink2" href="/examples/vision/mirnet/">Low-light image enhancement using MIRNet</a> <a class="nav-sublink2" href="/examples/vision/super_resolution_sub_pixel/">Image Super-Resolution using an Efficient Sub-Pixel CNN</a> <a class="nav-sublink2" href="/examples/vision/edsr/">Enhanced Deep Residual Networks for single-image super-resolution</a> <a class="nav-sublink2" href="/examples/vision/zero_dce/">Zero-DCE for low-light image enhancement</a> <a class="nav-sublink2" href="/examples/vision/cutmix/">CutMix data augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/mixup/">MixUp augmentation for image classification</a> <a class="nav-sublink2" href="/examples/vision/randaugment/">RandAugment for Image Classification for Improved Robustness</a> <a class="nav-sublink2" href="/examples/vision/image_captioning/">Image captioning</a> <a class="nav-sublink2" href="/examples/vision/nl_image_search/">Natural language image search with a Dual Encoder</a> <a class="nav-sublink2" href="/examples/vision/visualizing_what_convnets_learn/">Visualizing what convnets learn</a> <a class="nav-sublink2" href="/examples/vision/integrated_gradients/">Model interpretability with Integrated Gradients</a> <a class="nav-sublink2" href="/examples/vision/probing_vits/">Investigating Vision Transformer representations</a> <a class="nav-sublink2" href="/examples/vision/grad_cam/">Grad-CAM class activation visualization</a> <a class="nav-sublink2" href="/examples/vision/near_dup_search/">Near-duplicate image search</a> <a class="nav-sublink2" href="/examples/vision/semantic_image_clustering/">Semantic Image Clustering</a> <a class="nav-sublink2" href="/examples/vision/siamese_contrastive/">Image similarity estimation using a Siamese Network with a contrastive loss</a> <a class="nav-sublink2" href="/examples/vision/siamese_network/">Image similarity estimation using a Siamese Network with a triplet loss</a> <a class="nav-sublink2" href="/examples/vision/metric_learning/">Metric learning for image similarity search</a> <a class="nav-sublink2" href="/examples/vision/metric_learning_tf_similarity/">Metric learning for image similarity search using TensorFlow Similarity</a> <a class="nav-sublink2" href="/examples/vision/nnclr/">Self-supervised contrastive learning with NNCLR</a> <a class="nav-sublink2" href="/examples/vision/video_classification/">Video Classification with a CNN-RNN Architecture</a> <a class="nav-sublink2 active" href="/examples/vision/conv_lstm/">Next-Frame Video Prediction with Convolutional LSTMs</a> <a class="nav-sublink2" href="/examples/vision/video_transformers/">Video Classification with Transformers</a> <a class="nav-sublink2" href="/examples/vision/vivit/">Video Vision Transformer</a> <a class="nav-sublink2" href="/examples/vision/bit/">Image Classification using BigTransfer (BiT)</a> <a class="nav-sublink2" href="/examples/vision/gradient_centralization/">Gradient Centralization for Better Training Performance</a> <a class="nav-sublink2" href="/examples/vision/token_learner/">Learning to tokenize in Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/knowledge_distillation/">Knowledge Distillation</a> <a class="nav-sublink2" href="/examples/vision/fixres/">FixRes: Fixing train-test resolution discrepancy</a> <a class="nav-sublink2" href="/examples/vision/cait/">Class Attention Image Transformers with LayerScale</a> <a class="nav-sublink2" href="/examples/vision/patch_convnet/">Augmenting convnets with aggregated attention</a> <a class="nav-sublink2" href="/examples/vision/learnable_resizer/">Learning to Resize</a> <a class="nav-sublink2" href="/examples/vision/adamatch/">Semi-supervision and domain adaptation with AdaMatch</a> <a class="nav-sublink2" href="/examples/vision/barlow_twins/">Barlow Twins for Contrastive SSL</a> <a class="nav-sublink2" href="/examples/vision/consistency_training/">Consistency training with supervision</a> <a class="nav-sublink2" href="/examples/vision/deit/">Distilling Vision Transformers</a> <a class="nav-sublink2" href="/examples/vision/focal_modulation_network/">Focal Modulation: A replacement for Self-Attention</a> <a class="nav-sublink2" href="/examples/vision/forwardforward/">Using the Forward-Forward Algorithm for Image Classification</a> <a class="nav-sublink2" href="/examples/vision/masked_image_modeling/">Masked image modeling with Autoencoders</a> <a class="nav-sublink2" href="/examples/vision/sam/">Segment Anything Model with 🤗Transformers</a> <a class="nav-sublink2" href="/examples/vision/segformer/">Semantic segmentation with SegFormer and Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/vision/simsiam/">Self-supervised contrastive learning with SimSiam</a> <a class="nav-sublink2" href="/examples/vision/supervised-contrastive-learning/">Supervised Contrastive Learning</a> <a class="nav-sublink2" href="/examples/vision/temporal_latent_bottleneck/">When Recurrence meets Transformers</a> <a class="nav-sublink2" href="/examples/vision/yolov8/">Efficient Object Detection with YOLOV8 and KerasCV</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/vision/'>Computer Vision</a> / Next-Frame Video Prediction with Convolutional LSTMs </div> <div class='k-content'> <h1 id="nextframe-video-prediction-with-convolutional-lstms">Next-Frame Video Prediction with Convolutional LSTMs</h1> <p><strong>Author:</strong> <a href="https://github.com/amogh7joshi">Amogh Joshi</a><br> <strong>Date created:</strong> 2021/06/02<br> <strong>Last modified:</strong> 2023/11/10<br> <strong>Description:</strong> How to build and train a convolutional LSTM model for next-frame video prediction.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/conv_lstm.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/conv_lstm.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>The <a href="https://papers.nips.cc/paper/2015/file/07563a3fe3bbe7e3ba84431ad9d055af-Paper.pdf">Convolutional LSTM</a> architectures bring together time series processing and computer vision by introducing a convolutional recurrent cell in a LSTM layer. In this example, we will explore the Convolutional LSTM model in an application to next-frame prediction, the process of predicting what video frames come next given a series of past frames.</p> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> <span class="kn">import</span> <span class="nn">io</span> <span class="kn">import</span> <span class="nn">imageio</span> <span class="kn">from</span> <span class="nn">IPython.display</span> <span class="kn">import</span> <span class="n">Image</span><span class="p">,</span> <span class="n">display</span> <span class="kn">from</span> <span class="nn">ipywidgets</span> <span class="kn">import</span> <span class="n">widgets</span><span class="p">,</span> <span class="n">Layout</span><span class="p">,</span> <span class="n">HBox</span> </code></pre></div> <hr /> <h2 id="dataset-construction">Dataset Construction</h2> <p>For this example, we will be using the <a href="http://www.cs.toronto.edu/~nitish/unsupervised_video/">Moving MNIST</a> dataset.</p> <p>We will download the dataset and then construct and preprocess training and validation sets.</p> <p>For next-frame prediction, our model will be using a previous frame, which we'll call <code>f_n</code>, to predict a new frame, called <code>f_(n + 1)</code>. To allow the model to create these predictions, we'll need to process the data such that we have "shifted" inputs and outputs, where the input data is frame <code>x_n</code>, being used to predict frame <code>y_(n + 1)</code>.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Download and load the dataset.</span> <span class="n">fpath</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">get_file</span><span class="p">(</span> <span class="s2">"moving_mnist.npy"</span><span class="p">,</span> <span class="s2">"http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy"</span><span class="p">,</span> <span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">fpath</span><span class="p">)</span> <span class="c1"># Swap the axes representing the number of frames and number of data samples.</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">swapaxes</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span> <span class="c1"># We'll pick out 1000 of the 10000 total examples and use those.</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="p">[:</span><span class="mi">1000</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="c1"># Add a channel dimension since the images are grayscale.</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="c1"># Split into train and validation sets using indexing to optimize memory.</span> <span class="n">indexes</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">dataset</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">indexes</span><span class="p">)</span> <span class="n">train_index</span> <span class="o">=</span> <span class="n">indexes</span><span class="p">[:</span> <span class="nb">int</span><span class="p">(</span><span class="mf">0.9</span> <span class="o">*</span> <span class="n">dataset</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])]</span> <span class="n">val_index</span> <span class="o">=</span> <span class="n">indexes</span><span class="p">[</span><span class="nb">int</span><span class="p">(</span><span class="mf">0.9</span> <span class="o">*</span> <span class="n">dataset</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span> <span class="p">:]</span> <span class="n">train_dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="p">[</span><span class="n">train_index</span><span class="p">]</span> <span class="n">val_dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="p">[</span><span class="n">val_index</span><span class="p">]</span> <span class="c1"># Normalize the data to the 0-1 range.</span> <span class="n">train_dataset</span> <span class="o">=</span> <span class="n">train_dataset</span> <span class="o">/</span> <span class="mi">255</span> <span class="n">val_dataset</span> <span class="o">=</span> <span class="n">val_dataset</span> <span class="o">/</span> <span class="mi">255</span> <span class="c1"># We'll define a helper function to shift the frames, where</span> <span class="c1"># `x` is frames 0 to n - 1, and `y` is frames 1 to n.</span> <span class="k">def</span> <span class="nf">create_shifted_frames</span><span class="p">(</span><span class="n">data</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">data</span><span class="p">[:,</span> <span class="mi">0</span> <span class="p">:</span> <span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">-</span> <span class="mi">1</span><span class="p">,</span> <span class="p">:,</span> <span class="p">:]</span> <span class="n">y</span> <span class="o">=</span> <span class="n">data</span><span class="p">[:,</span> <span class="mi">1</span> <span class="p">:</span> <span class="n">data</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="p">:,</span> <span class="p">:]</span> <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="c1"># Apply the processing function to the datasets.</span> <span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span> <span class="o">=</span> <span class="n">create_shifted_frames</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">)</span> <span class="n">x_val</span><span class="p">,</span> <span class="n">y_val</span> <span class="o">=</span> <span class="n">create_shifted_frames</span><span class="p">(</span><span class="n">val_dataset</span><span class="p">)</span> <span class="c1"># Inspect the dataset.</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Training Dataset Shapes: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">x_train</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">+</span> <span class="s2">", "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">y_train</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Validation Dataset Shapes: "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">x_val</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="o">+</span> <span class="s2">", "</span> <span class="o">+</span> <span class="nb">str</span><span class="p">(</span><span class="n">y_val</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Downloading data from http://www.cs.toronto.edu/~nitish/unsupervised_video/mnist_test_seq.npy 819200096/819200096 ━━━━━━━━━━━━━━━━━━━━ 116s 0us/step Training Dataset Shapes: (900, 19, 64, 64, 1), (900, 19, 64, 64, 1) Validation Dataset Shapes: (100, 19, 64, 64, 1), (100, 19, 64, 64, 1) </code></pre></div> </div> <hr /> <h2 id="data-visualization">Data Visualization</h2> <p>Our data consists of sequences of frames, each of which are used to predict the upcoming frame. Let's take a look at some of these sequential frames.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Construct a figure on which we will visualize the images.</span> <span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">4</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">8</span><span class="p">))</span> <span class="c1"># Plot each of the sequential images for one random data example.</span> <span class="n">data_choice</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">)),</span> <span class="n">size</span><span class="o">=</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">ax</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">axes</span><span class="o">.</span><span class="n">flat</span><span class="p">):</span> <span class="n">ax</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">[</span><span class="n">data_choice</span><span class="p">][</span><span class="n">idx</span><span class="p">]),</span> <span class="n">cmap</span><span class="o">=</span><span class="s2">"gray"</span><span class="p">)</span> <span class="n">ax</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Frame </span><span class="si">{</span><span class="n">idx</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">1</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="n">ax</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="c1"># Print information and display the figure.</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Displaying frames for example </span><span class="si">{</span><span class="n">data_choice</span><span class="si">}</span><span class="s2">."</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Displaying frames for example 95. </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/conv_lstm/conv_lstm_7_1.png" /></p> <hr /> <h2 id="model-construction">Model Construction</h2> <p>To build a Convolutional LSTM model, we will use the <code>ConvLSTM2D</code> layer, which will accept inputs of shape <code>(batch_size, num_frames, width, height, channels)</code>, and return a prediction movie of the same shape.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Construct the input layer with no definite frame size.</span> <span class="n">inp</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="o">*</span><span class="n">x_train</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">2</span><span class="p">:]))</span> <span class="c1"># We will construct 3 `ConvLSTM2D` layers with batch normalization,</span> <span class="c1"># followed by a `Conv3D` layer for the spatiotemporal outputs.</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">ConvLSTM2D</span><span class="p">(</span> <span class="n">filters</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">,</span> <span class="n">return_sequences</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">,</span> <span class="p">)(</span><span class="n">inp</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">ConvLSTM2D</span><span class="p">(</span> <span class="n">filters</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">,</span> <span class="n">return_sequences</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">,</span> <span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">ConvLSTM2D</span><span class="p">(</span> <span class="n">filters</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span><span class="p">,</span> <span class="n">return_sequences</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">,</span> <span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv3D</span><span class="p">(</span> <span class="n">filters</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">kernel_size</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"sigmoid"</span><span class="p">,</span> <span class="n">padding</span><span class="o">=</span><span class="s2">"same"</span> <span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="c1"># Next, we will build the complete model and compile it.</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inp</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">binary_crossentropy</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(),</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="model-training">Model Training</h2> <p>With our model and data constructed, we can now train the model.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Define some callbacks to improve training.</span> <span class="n">early_stopping</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">EarlyStopping</span><span class="p">(</span><span class="n">monitor</span><span class="o">=</span><span class="s2">"val_loss"</span><span class="p">,</span> <span class="n">patience</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> <span class="n">reduce_lr</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">ReduceLROnPlateau</span><span class="p">(</span><span class="n">monitor</span><span class="o">=</span><span class="s2">"val_loss"</span><span class="p">,</span> <span class="n">patience</span><span class="o">=</span><span class="mi">5</span><span class="p">)</span> <span class="c1"># Define modifiable training hyperparameters.</span> <span class="n">epochs</span> <span class="o">=</span> <span class="mi">20</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">5</span> <span class="c1"># Fit the model to the training data.</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">epochs</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="p">(</span><span class="n">x_val</span><span class="p">,</span> <span class="n">y_val</span><span class="p">),</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">early_stopping</span><span class="p">,</span> <span class="n">reduce_lr</span><span class="p">],</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/20 180/180 ━━━━━━━━━━━━━━━━━━━━ 50s 226ms/step - loss: 0.1510 - val_loss: 0.2966 - learning_rate: 0.0010 Epoch 2/20 180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0287 - val_loss: 0.1766 - learning_rate: 0.0010 Epoch 3/20 180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0269 - val_loss: 0.0661 - learning_rate: 0.0010 Epoch 4/20 180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0264 - val_loss: 0.0279 - learning_rate: 0.0010 Epoch 5/20 180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0258 - val_loss: 0.0254 - learning_rate: 0.0010 Epoch 6/20 180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0256 - val_loss: 0.0253 - learning_rate: 0.0010 Epoch 7/20 180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0251 - val_loss: 0.0248 - learning_rate: 0.0010 Epoch 8/20 180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0251 - val_loss: 0.0251 - learning_rate: 0.0010 Epoch 9/20 180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0247 - val_loss: 0.0243 - learning_rate: 0.0010 Epoch 10/20 180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0246 - val_loss: 0.0246 - learning_rate: 0.0010 Epoch 11/20 180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0245 - val_loss: 0.0247 - learning_rate: 0.0010 Epoch 12/20 180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0241 - val_loss: 0.0243 - learning_rate: 0.0010 Epoch 13/20 180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0244 - val_loss: 0.0245 - learning_rate: 0.0010 Epoch 14/20 180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0241 - val_loss: 0.0241 - learning_rate: 0.0010 Epoch 15/20 180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0243 - val_loss: 0.0241 - learning_rate: 0.0010 Epoch 16/20 180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0242 - val_loss: 0.0242 - learning_rate: 0.0010 Epoch 17/20 180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0240 - val_loss: 0.0240 - learning_rate: 0.0010 Epoch 18/20 180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0240 - val_loss: 0.0243 - learning_rate: 0.0010 Epoch 19/20 180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0240 - val_loss: 0.0244 - learning_rate: 0.0010 Epoch 20/20 180/180 ━━━━━━━━━━━━━━━━━━━━ 40s 219ms/step - loss: 0.0237 - val_loss: 0.0238 - learning_rate: 1.0000e-04 <keras.src.callbacks.history.History at 0x7ff294f9c340> </code></pre></div> </div> <hr /> <h2 id="frame-prediction-visualizations">Frame Prediction Visualizations</h2> <p>With our model now constructed and trained, we can generate some example frame predictions based on a new video.</p> <p>We'll pick a random example from the validation set and then choose the first ten frames from them. From there, we can allow the model to predict 10 new frames, which we can compare to the ground truth frame predictions.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Select a random example from the validation dataset.</span> <span class="n">example</span> <span class="o">=</span> <span class="n">val_dataset</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">val_dataset</span><span class="p">)),</span> <span class="n">size</span><span class="o">=</span><span class="mi">1</span><span class="p">)[</span><span class="mi">0</span><span class="p">]]</span> <span class="c1"># Pick the first/last ten frames from the example.</span> <span class="n">frames</span> <span class="o">=</span> <span class="n">example</span><span class="p">[:</span><span class="mi">10</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="n">original_frames</span> <span class="o">=</span> <span class="n">example</span><span class="p">[</span><span class="mi">10</span><span class="p">:,</span> <span class="o">...</span><span class="p">]</span> <span class="c1"># Predict a new set of 10 frames.</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">10</span><span class="p">):</span> <span class="c1"># Extract the model's prediction and post-process it.</span> <span class="n">new_prediction</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">frames</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">))</span> <span class="n">new_prediction</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">new_prediction</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">predicted_frame</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">new_prediction</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="c1"># Extend the set of prediction frames.</span> <span class="n">frames</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="n">frames</span><span class="p">,</span> <span class="n">predicted_frame</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="c1"># Construct a figure for the original and new frames.</span> <span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="mi">10</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span> <span class="c1"># Plot the original frames.</span> <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">ax</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">axes</span><span class="p">[</span><span class="mi">0</span><span class="p">]):</span> <span class="n">ax</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">original_frames</span><span class="p">[</span><span class="n">idx</span><span class="p">]),</span> <span class="n">cmap</span><span class="o">=</span><span class="s2">"gray"</span><span class="p">)</span> <span class="n">ax</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Frame </span><span class="si">{</span><span class="n">idx</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">11</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="n">ax</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="c1"># Plot the new frames.</span> <span class="n">new_frames</span> <span class="o">=</span> <span class="n">frames</span><span class="p">[</span><span class="mi">10</span><span class="p">:,</span> <span class="o">...</span><span class="p">]</span> <span class="k">for</span> <span class="n">idx</span><span class="p">,</span> <span class="n">ax</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">axes</span><span class="p">[</span><span class="mi">1</span><span class="p">]):</span> <span class="n">ax</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">new_frames</span><span class="p">[</span><span class="n">idx</span><span class="p">]),</span> <span class="n">cmap</span><span class="o">=</span><span class="s2">"gray"</span><span class="p">)</span> <span class="n">ax</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Frame </span><span class="si">{</span><span class="n">idx</span><span class="w"> </span><span class="o">+</span><span class="w"> </span><span class="mi">11</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="n">ax</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span> <span class="c1"># Display the figure.</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 1/1 ━━━━━━━━━━━━━━━━━━━━ 2s 2s/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 800ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 805ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 790ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 821ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 824ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 928ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 813ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 810ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 814ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/conv_lstm/conv_lstm_13_1.png" /></p> <hr /> <h2 id="predicted-videos">Predicted Videos</h2> <p>Finally, we'll pick a few examples from the validation set and construct some GIFs with them to see the model's predicted videos.</p> <p>You can use the trained model hosted on <a href="https://huggingface.co/keras-io/conv-lstm">Hugging Face Hub</a> and try the demo on <a href="https://huggingface.co/spaces/keras-io/conv-lstm">Hugging Face Spaces</a>.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Select a few random examples from the dataset.</span> <span class="n">examples</span> <span class="o">=</span> <span class="n">val_dataset</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span><span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">val_dataset</span><span class="p">)),</span> <span class="n">size</span><span class="o">=</span><span class="mi">5</span><span class="p">)]</span> <span class="c1"># Iterate over the examples and predict the frames.</span> <span class="n">predicted_videos</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">example</span> <span class="ow">in</span> <span class="n">examples</span><span class="p">:</span> <span class="c1"># Pick the first/last ten frames from the example.</span> <span class="n">frames</span> <span class="o">=</span> <span class="n">example</span><span class="p">[:</span><span class="mi">10</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="n">original_frames</span> <span class="o">=</span> <span class="n">example</span><span class="p">[</span><span class="mi">10</span><span class="p">:,</span> <span class="o">...</span><span class="p">]</span> <span class="n">new_predictions</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="o">*</span><span class="n">frames</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">))</span> <span class="c1"># Predict a new set of 10 frames.</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">10</span><span class="p">):</span> <span class="c1"># Extract the model's prediction and post-process it.</span> <span class="n">frames</span> <span class="o">=</span> <span class="n">example</span><span class="p">[:</span> <span class="mi">10</span> <span class="o">+</span> <span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="n">new_prediction</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">frames</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">))</span> <span class="n">new_prediction</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">new_prediction</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">predicted_frame</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">new_prediction</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="o">...</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="c1"># Extend the set of prediction frames.</span> <span class="n">new_predictions</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">=</span> <span class="n">predicted_frame</span> <span class="c1"># Create and save GIFs for each of the ground truth/prediction images.</span> <span class="k">for</span> <span class="n">frame_set</span> <span class="ow">in</span> <span class="p">[</span><span class="n">original_frames</span><span class="p">,</span> <span class="n">new_predictions</span><span class="p">]:</span> <span class="c1"># Construct a GIF from the selected video frames.</span> <span class="n">current_frames</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">frame_set</span><span class="p">)</span> <span class="n">current_frames</span> <span class="o">=</span> <span class="n">current_frames</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="n">np</span><span class="o">.</span><span class="n">newaxis</span><span class="p">]</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">ones</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span> <span class="n">current_frames</span> <span class="o">=</span> <span class="p">(</span><span class="n">current_frames</span> <span class="o">*</span> <span class="mi">255</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span> <span class="n">current_frames</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">current_frames</span><span class="p">)</span> <span class="c1"># Construct a GIF from the frames.</span> <span class="k">with</span> <span class="n">io</span><span class="o">.</span><span class="n">BytesIO</span><span class="p">()</span> <span class="k">as</span> <span class="n">gif</span><span class="p">:</span> <span class="n">imageio</span><span class="o">.</span><span class="n">mimsave</span><span class="p">(</span><span class="n">gif</span><span class="p">,</span> <span class="n">current_frames</span><span class="p">,</span> <span class="s2">"GIF"</span><span class="p">,</span> <span class="n">duration</span><span class="o">=</span><span class="mi">200</span><span class="p">)</span> <span class="n">predicted_videos</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">gif</span><span class="o">.</span><span class="n">getvalue</span><span class="p">())</span> <span class="c1"># Display the videos.</span> <span class="nb">print</span><span class="p">(</span><span class="s2">" Truth</span><span class="se">\t</span><span class="s2">Prediction"</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">predicted_videos</span><span class="p">),</span> <span class="mi">2</span><span class="p">):</span> <span class="c1"># Construct and display an `HBox` with the ground truth and prediction.</span> <span class="n">box</span> <span class="o">=</span> <span class="n">HBox</span><span class="p">(</span> <span class="p">[</span> <span class="n">widgets</span><span class="o">.</span><span class="n">Image</span><span class="p">(</span><span class="n">value</span><span class="o">=</span><span class="n">predicted_videos</span><span class="p">[</span><span class="n">i</span><span class="p">]),</span> <span class="n">widgets</span><span class="o">.</span><span class="n">Image</span><span class="p">(</span><span class="n">value</span><span class="o">=</span><span class="n">predicted_videos</span><span class="p">[</span><span class="n">i</span> <span class="o">+</span> <span class="mi">1</span><span class="p">]),</span> <span class="p">]</span> <span class="p">)</span> <span class="n">display</span><span class="p">(</span><span class="n">box</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 790ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 9ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 6ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 10ms/step Truth Prediction HBox(children=(Image(value=b'GIF89a@\x00@\x00\x87\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xf8\… HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\xff\xff\xff\xfd\xfd\xfd\xfc\xfc\xfc\xfb\xfb\xfb\xf4\… HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xfb\… HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\xff\xff\xff\xfe\xfe\xfe\xfd\xfd\xfd\xfc\xfc\xfc\xfb\… HBox(children=(Image(value=b'GIF89a@\x00@\x00\x86\x00\x00\xff\xff\xff\xfd\xfd\xfd\xfc\xfc\xfc\xf9\xf9\xf9\xf7\… </code></pre></div> </div> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#nextframe-video-prediction-with-convolutional-lstms'>Next-Frame Video Prediction with Convolutional LSTMs</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#dataset-construction'>Dataset Construction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#data-visualization'>Data Visualization</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#model-construction'>Model Construction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#model-training'>Model Training</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#frame-prediction-visualizations'>Frame Prediction Visualizations</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#predicted-videos'>Predicted Videos</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>