<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/vision/nerf/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: 3D volumetric rendering with NeRF"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: 3D volumetric rendering with NeRF"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>3D volumetric rendering with NeRF</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> </head> <body> <div class='k-page'> <div class='k-main'> <div class='k-main-inner' id='k-main-id'> <div class='k-content'> <h1
id="3d-volumetric-rendering-with-nerf">3D volumetric rendering with NeRF</h1> <p><strong>Authors:</strong> <a href="https://twitter.com/arig23498">Aritra Roy Gosthipaty</a>, <a href="https://twitter.com/ritwik_raha">Ritwik Raha</a><br> <strong>Date created:</strong> 2021/08/09<br> <strong>Last modified:</strong> 2023/11/13<br> <strong>Description:</strong> Minimal implementation of volumetric rendering as shown in NeRF.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/nerf.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/vision/nerf.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>In this example, we present a minimal implementation of the research paper <a href="https://arxiv.org/abs/2003.08934"><strong>NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis</strong></a> by Ben Mildenhall et al. 
The authors propose an ingenious way to <em>synthesize novel views of a scene</em> by modelling the <em>volumetric scene function</em> with a neural network.</p> <p>To build intuition, let's start with the following question: <em>would it be possible to give a neural network the position of a pixel in an image, and ask the network to predict the color at that position?</em></p> <table> <thead> <tr> <th style="text-align: center;"><img alt="2d-train" src="https://i.imgur.com/DQM92vN.png" /></th> </tr> </thead> <tbody> <tr> <td style="text-align: center;"><strong>Figure 1</strong>: A neural network being given coordinates of an image</td> </tr> <tr> <td style="text-align: center;">as input and asked to predict the color at the coordinates.</td> </tr> </tbody> </table> <p>The neural network would hypothetically <em>memorize</em> (overfit on) the image. This means that the network would encode the entire image in its weights. We could then query the network with each pixel position, and it would reconstruct the entire image.</p> <table> <thead> <tr> <th style="text-align: center;"><img alt="2d-test" src="https://i.imgur.com/6Qz5Hp1.png" /></th> </tr> </thead> <tbody> <tr> <td style="text-align: center;"><strong>Figure 2</strong>: The trained neural network recreates the image from scratch.</td> </tr> </tbody> </table> <p>A question now arises: how do we extend this idea to learn a 3D volumetric scene? Implementing a process similar to the one above would require knowledge of every voxel (volume pixel), which turns out to be quite challenging.</p> <p>The authors of the paper propose a minimal and elegant way to learn a 3D scene using a few images of the scene. They discard the use of voxels for training. 
The network learns to model the volumetric scene, and can therefore generate novel views (images) of the 3D scene that it was not shown at training time.</p> <p>There are a few prerequisites to understand before we can fully appreciate the process. The example is structured so that you will have all the required knowledge before starting the implementation.</p> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span>

<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"tensorflow"</span>

<span class="c1"># Setting random seed to obtain reproducible results.</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span>

<span class="n">tf</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">set_seed</span><span class="p">(</span><span class="mi">42</span><span class="p">)</span>

<span class="kn">import</span> <span class="nn">keras</span>
<span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span>

<span class="kn">import</span> <span class="nn">glob</span>
<span class="kn">import</span> <span class="nn">imageio.v2</span> <span class="k">as</span> <span class="nn">imageio</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">from</span> <span class="nn">tqdm</span> <span class="kn">import</span> <span class="n">tqdm</span>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>

<span class="c1"># Initialize global variables.</span>
<span class="n">AUTO</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span>
<span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">5</span>
<span class="n">NUM_SAMPLES</span> <span class="o">=</span> <span class="mi">32</span>
<span class="n">POS_ENCODE_DIMS</span> <span class="o">=</span> <span class="mi">16</span>
<span class="n">EPOCHS</span> <span class="o">=</span> <span class="mi">20</span>
</code></pre></div> <hr /> <h2 id="download-and-load-the-data">Download and load the data</h2> <p>The <code>npz</code> data file contains images, camera poses, and a focal length. The images are taken from multiple camera angles as shown in <strong>Figure 3</strong>.</p> <table> <thead> <tr> <th style="text-align: center;"><img alt="camera-angles" src="https://i.imgur.com/FLsi2is.png" /></th> </tr> </thead> <tbody> <tr> <td style="text-align: center;"><strong>Figure 3</strong>: Multiple camera angles <br></td> </tr> <tr> <td style="text-align: center;"><a href="https://arxiv.org/abs/2003.08934">Source: NeRF</a></td> </tr> </tbody> </table> <p>To understand camera poses in this context, we first need to think of a <em>camera as a mapping between the real world and the 2-D image</em>.</p> <table> <thead> <tr> <th style="text-align: center;"><img alt="mapping" src="https://www.mathworks.com/help/vision/ug/calibration_coordinate_blocks.png" /></th> </tr> </thead> <tbody> <tr> <td style="text-align: center;"><strong>Figure 4</strong>: 3-D world to 2-D image mapping through a camera <br></td> </tr> <tr> <td style="text-align: center;"><a href="https://www.mathworks.com/help/vision/ug/camera-calibration.html">Source: Mathworks</a></td> </tr> </tbody> </table> <p>Consider the following equation:</p> <p><img src="https://i.imgur.com/TQHKx5v.pngg" width="100" height="50"/></p> <p>Where <strong>x</strong> is the 2-D image point, 
<strong>X</strong> is the 3-D world point and <strong>P</strong> is the camera matrix. <strong>P</strong> is a 3 x 4 matrix that plays the crucial role of mapping a real-world object onto the image plane.</p> <p><img src="https://i.imgur.com/chvJct5.png" width="300" height="100"/></p> <p>The camera matrix is an <em>affine transform matrix</em> that is concatenated with a 3 x 1 column <code>[image height, image width, focal length]</code> to produce the <em>pose matrix</em>. This matrix has dimensions 3 x 5, where the first 3 x 3 block is in the camera’s point of view. The axes are <code>[down, right, backwards]</code> or <code>[-y, x, z]</code>, where the camera faces forward along <code>-z</code>.</p> <table> <thead> <tr> <th style="text-align: center;"><img alt="camera-mapping" src="https://i.imgur.com/kvjqbiO.png" /></th> </tr> </thead> <tbody> <tr> <td style="text-align: center;"><strong>Figure 5</strong>: The affine transformation.</td> </tr> </tbody> </table> <p>The COLMAP frame is <code>[right, down, forwards]</code> or <code>[x, -y, -z]</code>. 
Read more about COLMAP <a href="https://colmap.github.io/">here</a>.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Download the data if it does not already exist.</span>
<span class="n">url</span> <span class="o">=</span> <span class="p">(</span>
    <span class="s2">"http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz"</span>
<span class="p">)</span>
<span class="n">data</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">get_file</span><span class="p">(</span><span class="n">origin</span><span class="o">=</span><span class="n">url</span><span class="p">)</span>

<span class="n">data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">load</span><span class="p">(</span><span class="n">data</span><span class="p">)</span>
<span class="n">images</span> <span class="o">=</span> <span class="n">data</span><span class="p">[</span><span class="s2">"images"</span><span class="p">]</span>
<span class="n">im_shape</span> <span class="o">=</span> <span class="n">images</span><span class="o">.</span><span class="n">shape</span>
<span class="p">(</span><span class="n">num_images</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">_</span><span class="p">)</span> <span class="o">=</span> <span class="n">images</span><span class="o">.</span><span class="n">shape</span>
<span class="p">(</span><span class="n">poses</span><span class="p">,</span> <span class="n">focal</span><span class="p">)</span> <span class="o">=</span> <span class="p">(</span><span class="n">data</span><span class="p">[</span><span class="s2">"poses"</span><span class="p">],</span> <span class="n">data</span><span class="p">[</span><span class="s2">"focal"</span><span class="p">])</span>

<span class="c1"># Plot a random image from the dataset for visualization.</span>
<span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">images</span><span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="n">low</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">high</span><span class="o">=</span><span class="n">num_images</span><span class="p">)])</span>
<span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span>
</code></pre></div> <p><img alt="png" src="/img/examples/vision/nerf/nerf_5_0.png" /></p> <hr /> <h2 id="data-pipeline">Data pipeline</h2> <p>Now that you've understood the notion of a camera matrix and the mapping from a 3D scene to 2D images, let's talk about the inverse mapping, i.e. from a 2D image to the 3D scene.</p> <p>We'll need to talk about volumetric rendering with ray casting and tracing, which are common computer graphics techniques. This section will help you get up to speed with these techniques.</p> <p>Consider an image with <code>N</code> pixels. We shoot a ray through each pixel and sample some points on the ray. A ray is commonly parameterized by the equation <code>r(t) = o + td</code> where <code>t</code> is the parameter, <code>o</code> is the origin and <code>d</code> is the unit direction vector as shown in <strong>Figure 6</strong>.</p> <table> <thead> <tr> <th style="text-align: center;"><img alt="img" src="https://i.imgur.com/ywrqlzt.gif" /></th> </tr> </thead> <tbody> <tr> <td style="text-align: center;"><strong>Figure 6</strong>: <code>r(t) = o + td</code> where t is 3</td> </tr> </tbody> </table> <p>In <strong>Figure 7</strong>, we consider a ray and sample some random points along it. 
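</p> <p>As a rough illustration of the ray equation <code>r(t) = o + td</code> and of picking sample points along a ray, here is a minimal sketch in plain Python. It is not the code used later in this example: the helper names <code>point_on_ray</code> and <code>sample_ray</code> and the <code>near</code>/<code>far</code> bounds are assumptions made for illustration, and the optional jitter simply adds uniform noise inside each evenly spaced bin.</p>

```python
import random


def point_on_ray(origin, direction, t):
    """Evaluate r(t) = o + t * d for 3-D tuples (hypothetical helper)."""
    return tuple(o + t * d for o, d in zip(origin, direction))


def sample_ray(origin, direction, near, far, num_samples, jitter=False, rng=None):
    """Pick sample points on a ray between the assumed bounds near and far.

    The interval [near, far] is split into num_samples equal bins and the
    start of each bin is used as t. With jitter=True, uniform noise no larger
    than one bin width is added to every t.
    """
    rng = rng or random.Random(0)
    bin_width = (far - near) / num_samples
    t_vals = [near + k * bin_width for k in range(num_samples)]
    if jitter:
        t_vals = [t + rng.random() * bin_width for t in t_vals]
    points = [point_on_ray(origin, direction, t) for t in t_vals]
    return t_vals, points


# A camera at the origin looking along -z, as in the pose convention above.
origin = (0.0, 0.0, 0.0)
direction = (0.0, 0.0, -1.0)
print(point_on_ray(origin, direction, 3.0))  # (0.0, 0.0, -3.0)

t_vals, points = sample_ray(origin, direction, near=2.0, far=6.0, num_samples=4)
print(t_vals)  # [2.0, 3.0, 4.0, 5.0]
```

<p>With <code>t = 3</code> the sketch returns the point three units along the ray, matching the setting of <strong>Figure 6</strong>.</p>
<p>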
These sample points each have a unique location <code>(x, y, z)</code> and the ray has a viewing angle <code>(theta, phi)</code>. The viewing angle is particularly interesting as we can shoot a ray through a single pixel in many different ways, each with a unique viewing angle. Another interesting thing to notice here is the noise that is added to the sampling process. We add uniform noise to each sample so that the samples correspond to a continuous distribution. In <strong>Figure 7</strong> the blue points are the evenly distributed samples and the white points <code>(t1, t2, t3)</code> are randomly placed between the samples.</p> <table> <thead> <tr> <th style="text-align: center;"><img alt="img" src="https://i.imgur.com/r9TS2wv.gif" /></th> </tr> </thead> <tbody> <tr> <td style="text-align: center;"><strong>Figure 7</strong>: Sampling the points from a ray.</td> </tr> </tbody> </table> <p><strong>Figure 8</strong> showcases the entire sampling process in 3D, where you can see the rays coming out of the white image. This means that each pixel will have its corresponding ray and each ray will be sampled at distinct points.</p> <table> <thead> <tr> <th style="text-align: center;"><img alt="3-d rays" src="https://i.imgur.com/hr4D2g2.gif" /></th> </tr> </thead> <tbody> <tr> <td style="text-align: center;"><strong>Figure 8</strong>: Shooting rays from all the pixels of an image in 3-D</td> </tr> </tbody> </table> <p>These sampled points act as the input to the NeRF model. 
The model is then asked to predict the RGB color and the volume density at that point.</p> <table> <thead> <tr> <th style="text-align: center;"><img alt="3-Drender" src="https://i.imgur.com/HHb6tlQ.png" /></th> </tr> </thead> <tbody> <tr> <td style="text-align: center;"><strong>Figure 9</strong>: Data pipeline <br></td> </tr> <tr> <td style="text-align: center;"><a href="https://arxiv.org/abs/2003.08934">Source: NeRF</a></td> </tr> </tbody> </table> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">encode_position</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="w">    </span><span class="sd">"""Encodes the position into its corresponding Fourier feature.</span>

<span class="sd">    Args:</span>
<span class="sd">        x: The input coordinate.</span>

<span class="sd">    Returns:</span>
<span class="sd">        Fourier features tensors of the position.</span>
<span class="sd">    """</span>
    <span class="n">positions</span> <span class="o">=</span> <span class="p">[</span><span class="n">x</span><span class="p">]</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">POS_ENCODE_DIMS</span><span class="p">):</span>
        <span class="k">for</span> <span class="n">fn</span> <span class="ow">in</span> <span class="p">[</span><span class="n">tf</span><span class="o">.</span><span class="n">sin</span><span class="p">,</span> <span class="n">tf</span><span class="o">.</span><span class="n">cos</span><span class="p">]:</span>
            <span class="n">positions</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">fn</span><span class="p">(</span><span class="mf">2.0</span><span class="o">**</span><span class="n">i</span> <span class="o">*</span> <span class="n">x</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">tf</span><span class="o">.</span><span class="n">concat</span><span class="p">(</span><span class="n">positions</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>


<span class="k">def</span> <span class="nf">get_rays</span><span class="p">(</span><span class="n">height</span><span class="p">,</span> <span class="n">width</span><span class="p">,</span> <span class="n">focal</span><span class="p">,</span> <span class="n">pose</span><span class="p">):</span>
<span class="w">    </span><span class="sd">"""Computes origin point and direction vector of rays.</span>

<span class="sd">    Args:</span>
<span class="sd">        height: Height of the image.</span>
<span class="sd">        width: Width of the image.</span>
<span class="sd">        focal: The focal length between the images and the camera.</span>
<span class="sd">        pose: The pose matrix of the camera.</span>

<span class="sd">    Returns:</span>
<span class="sd">        Tuple of origin point and direction vector for rays.</span>
<span class="sd">    """</span>
    <span class="c1"># Build a meshgrid for the rays.</span>
    <span class="n">i</span><span class="p">,</span> <span class="n">j</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">meshgrid</span><span class="p">(</span>
        <span class="n">tf</span><span class="o">.</span><span class="n">range</span><span class="p">(</span><span class="n">width</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
        <span class="n">tf</span><span class="o">.</span><span class="n">range</span><span class="p">(</span><span class="n">height</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">),</span>
        <span class="n">indexing</span><span class="o">=</span><span class="s2">"xy"</span><span class="p">,</span>
    <span class="p">)</span>

    <span class="c1"># Normalize the x axis coordinates.</span>
    <span class="n">transformed_i</span> <span class="o">=</span> <span class="p">(</span><span class="n">i</span> <span class="o">-</span> <span class="n">width</span> <span class="o">*</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">/</span> <span class="n">focal</span>

    <span class="c1"># Normalize the y axis coordinates.</span>
    <span class="n">transformed_j</span> <span class="o">=</span> <span class="p">(</span><span class="n">j</span> <span class="o">-</span> <span class="n">height</span> <span class="o">*</span> <span class="mf">0.5</span><span class="p">)</span> <span class="o">/</span> <span class="n">focal</span>

    <span class="c1"># Create the direction unit vectors.</span>
    <span class="n">directions</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">transformed_i</span><span class="p">,</span> <span class="o">-</span><span class="n">transformed_j</span><span class="p">,</span> <span class="o">-</span><span class="n">tf</span><span class="o">.</span><span class="n">ones_like</span><span class="p">(</span><span class="n">i</span><span class="p">)],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>

    <span class="c1"># Get the camera matrix.</span>
    <span class="n">camera_matrix</span> <span class="o">=</span> <span class="n">pose</span><span class="p">[:</span><span class="mi">3</span><span class="p">,</span> <span class="p">:</span><span class="mi">3</span><span class="p">]</span>
    <span class="n">height_width_focal</span> <span class="o">=</span> <span class="n">pose</span><span class="p">[:</span><span class="mi">3</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">]</span>

    <span class="c1"># Get origins and 
directions for the rays.</span> <span class="n">transformed_dirs</span> <span class="o">=</span> <span class="n">directions</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="n">camera_dirs</span> <span class="o">=</span> <span class="n">transformed_dirs</span> <span class="o">*</span> <span class="n">camera_matrix</span> <span class="n">ray_directions</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">camera_dirs</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="n">ray_origins</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">(</span><span class="n">height_width_focal</span><span class="p">,</span> <span class="n">tf</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">ray_directions</span><span class="p">))</span> <span class="c1"># Return the origins and directions.</span> <span class="k">return</span> <span class="p">(</span><span class="n">ray_origins</span><span class="p">,</span> <span class="n">ray_directions</span><span class="p">)</span> <span class="k">def</span> <span class="nf">render_flat_rays</span><span class="p">(</span><span class="n">ray_origins</span><span class="p">,</span> <span class="n">ray_directions</span><span class="p">,</span> <span class="n">near</span><span class="p">,</span> <span class="n">far</span><span class="p">,</span> <span class="n">num_samples</span><span class="p">,</span> <span class="n">rand</span><span class="o">=</span><span class="kc">False</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Renders the rays and flattens them.</span> <span class="sd">
Args:</span> <span class="sd"> ray_origins: The origin points for rays.</span> <span class="sd"> ray_directions: The direction unit vectors for the rays.</span> <span class="sd"> near: The near bound of the volumetric scene.</span> <span class="sd"> far: The far bound of the volumetric scene.</span> <span class="sd"> num_samples: Number of sample points in a ray.</span> <span class="sd"> rand: Choice for randomising the sampling strategy.</span> <span class="sd"> Returns:</span> <span class="sd"> Tuple of flattened rays and sample points on each ray.</span> <span class="sd"> """</span> <span class="c1"># Compute 3D query points.</span> <span class="c1"># Equation: r(t) = o + td -> Building the "t" here.</span> <span class="n">t_vals</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="n">near</span><span class="p">,</span> <span class="n">far</span><span class="p">,</span> <span class="n">num_samples</span><span class="p">)</span> <span class="k">if</span> <span class="n">rand</span><span class="p">:</span> <span class="c1"># Inject uniform noise into sample space to make the sampling</span> <span class="c1"># continuous.</span> <span class="n">shape</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">ray_origins</span><span class="o">.</span><span class="n">shape</span><span class="p">[:</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span> <span class="o">+</span> <span class="p">[</span><span class="n">num_samples</span><span class="p">]</span> <span class="n">noise</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">uniform</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="n">shape</span><span class="p">)</span> <span class="o">*</span> <span
class="p">(</span><span class="n">far</span> <span class="o">-</span> <span class="n">near</span><span class="p">)</span> <span class="o">/</span> <span class="n">num_samples</span> <span class="n">t_vals</span> <span class="o">=</span> <span class="n">t_vals</span> <span class="o">+</span> <span class="n">noise</span> <span class="c1"># Equation: r(t) = o + td -> Building the "r" here.</span> <span class="n">rays</span> <span class="o">=</span> <span class="n">ray_origins</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">+</span> <span class="p">(</span> <span class="n">ray_directions</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:]</span> <span class="o">*</span> <span class="n">t_vals</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span> <span class="p">)</span> <span class="n">rays_flat</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">rays</span><span class="p">,</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="n">rays_flat</span> <span class="o">=</span> <span class="n">encode_position</span><span class="p">(</span><span class="n">rays_flat</span><span class="p">)</span> <span class="k">return</span> <span class="p">(</span><span class="n">rays_flat</span><span class="p">,</span> <span class="n">t_vals</span><span class="p">)</span> <span class="k">def</span> <span class="nf">map_fn</span><span class="p">(</span><span class="n">pose</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Maps individual pose to 
flattened rays and sample points.</span> <span class="sd"> Args:</span> <span class="sd"> pose: The pose matrix of the camera.</span> <span class="sd"> Returns:</span> <span class="sd"> Tuple of flattened rays and sample points corresponding to the</span> <span class="sd"> camera pose.</span> <span class="sd"> """</span> <span class="p">(</span><span class="n">ray_origins</span><span class="p">,</span> <span class="n">ray_directions</span><span class="p">)</span> <span class="o">=</span> <span class="n">get_rays</span><span class="p">(</span><span class="n">height</span><span class="o">=</span><span class="n">H</span><span class="p">,</span> <span class="n">width</span><span class="o">=</span><span class="n">W</span><span class="p">,</span> <span class="n">focal</span><span class="o">=</span><span class="n">focal</span><span class="p">,</span> <span class="n">pose</span><span class="o">=</span><span class="n">pose</span><span class="p">)</span> <span class="p">(</span><span class="n">rays_flat</span><span class="p">,</span> <span class="n">t_vals</span><span class="p">)</span> <span class="o">=</span> <span class="n">render_flat_rays</span><span class="p">(</span> <span class="n">ray_origins</span><span class="o">=</span><span class="n">ray_origins</span><span class="p">,</span> <span class="n">ray_directions</span><span class="o">=</span><span class="n">ray_directions</span><span class="p">,</span> <span class="n">near</span><span class="o">=</span><span class="mf">2.0</span><span class="p">,</span> <span class="n">far</span><span class="o">=</span><span class="mf">6.0</span><span class="p">,</span> <span class="n">num_samples</span><span class="o">=</span><span class="n">NUM_SAMPLES</span><span class="p">,</span> <span class="n">rand</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="k">return</span> <span class="p">(</span><span class="n">rays_flat</span><span class="p">,</span> <span 
class="n">t_vals</span><span class="p">)</span> <span class="c1"># Create the training split.</span> <span class="n">split_index</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">num_images</span> <span class="o">*</span> <span class="mf">0.8</span><span class="p">)</span> <span class="c1"># Split the images into training and validation.</span> <span class="n">train_images</span> <span class="o">=</span> <span class="n">images</span><span class="p">[:</span><span class="n">split_index</span><span class="p">]</span> <span class="n">val_images</span> <span class="o">=</span> <span class="n">images</span><span class="p">[</span><span class="n">split_index</span><span class="p">:]</span> <span class="c1"># Split the poses into training and validation.</span> <span class="n">train_poses</span> <span class="o">=</span> <span class="n">poses</span><span class="p">[:</span><span class="n">split_index</span><span class="p">]</span> <span class="n">val_poses</span> <span class="o">=</span> <span class="n">poses</span><span class="p">[</span><span class="n">split_index</span><span class="p">:]</span> <span class="c1"># Make the training pipeline.</span> <span class="n">train_img_ds</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">(</span><span class="n">train_images</span><span class="p">)</span> <span class="n">train_pose_ds</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">(</span><span class="n">train_poses</span><span class="p">)</span> <span class="n">train_ray_ds</span> <span class="o">=</span> <span class="n">train_pose_ds</span><span 
class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">map_fn</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTO</span><span class="p">)</span> <span class="n">training_ds</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">zip</span><span class="p">((</span><span class="n">train_img_ds</span><span class="p">,</span> <span class="n">train_ray_ds</span><span class="p">))</span> <span class="n">train_ds</span> <span class="o">=</span> <span class="p">(</span> <span class="n">training_ds</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span> <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">drop_remainder</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTO</span><span class="p">)</span> <span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">AUTO</span><span class="p">)</span> <span class="p">)</span> <span class="c1"># Make the validation pipeline.</span> <span class="n">val_img_ds</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">(</span><span class="n">val_images</span><span class="p">)</span> <span class="n">val_pose_ds</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span 
class="o">.</span><span class="n">from_tensor_slices</span><span class="p">(</span><span class="n">val_poses</span><span class="p">)</span> <span class="n">val_ray_ds</span> <span class="o">=</span> <span class="n">val_pose_ds</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">map_fn</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTO</span><span class="p">)</span> <span class="n">validation_ds</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">zip</span><span class="p">((</span><span class="n">val_img_ds</span><span class="p">,</span> <span class="n">val_ray_ds</span><span class="p">))</span> <span class="n">val_ds</span> <span class="o">=</span> <span class="p">(</span> <span class="n">validation_ds</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span> <span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">drop_remainder</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">AUTO</span><span class="p">)</span> <span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">AUTO</span><span class="p">)</span> <span class="p">)</span> </code></pre></div> <hr /> <h2 id="nerf-model">NeRF model</h2> <p>The model is a multi-layer perceptron (MLP), with ReLU as its non-linearity.</p> <p>An excerpt from the paper:</p> <p><em>"We encourage the representation to be multiview-consistent by restricting the network to predict the volume density sigma as a function of only the location <code>x</code>, 
while allowing the RGB color <code>c</code> to be predicted as a function of both location and viewing direction. To accomplish this, the MLP first processes the input 3D coordinate <code>x</code> with 8 fully-connected layers (using ReLU activations and 256 channels per layer), and outputs sigma and a 256-dimensional feature vector. This feature vector is then concatenated with the camera ray's viewing direction and passed to one additional fully-connected layer (using a ReLU activation and 128 channels) that output the view-dependent RGB color."</em></p> <p>This is a minimal implementation: each Dense layer has 64 units instead of the 256 used in the paper.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">get_nerf_model</span><span class="p">(</span><span class="n">num_layers</span><span class="p">,</span> <span class="n">num_pos</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Generates the NeRF neural network.</span> <span class="sd"> Args:</span> <span class="sd"> num_layers: The number of MLP layers.</span> <span class="sd"> num_pos: The number of encoded sample points per input (the first input dimension).</span> <span class="sd"> Returns:</span> <span class="sd"> The `keras` model.</span> <span class="sd"> """</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">num_pos</span><span class="p">,</span> <span class="mi">2</span> <span class="o">*</span> <span class="mi">3</span> <span class="o">*</span> <span class="n">POS_ENCODE_DIMS</span> <span class="o">+</span> <span class="mi">3</span><span class="p">))</span> <span class="n">x</span> <span class="o">=</span> <span class="n">inputs</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span
class="nb">range</span><span class="p">(</span><span class="n">num_layers</span><span class="p">):</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="mi">64</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="k">if</span> <span class="n">i</span> <span class="o">%</span> <span class="mi">4</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">i</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span> <span class="c1"># Inject residual connection.</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">concatenate</span><span class="p">([</span><span class="n">x</span><span class="p">,</span> <span class="n">inputs</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">units</span><span class="o">=</span><span class="mi">4</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">outputs</span><span class="p">)</span> <span class="k">def</span> <span class="nf">render_rgb_depth</span><span class="p">(</span><span 
class="n">model</span><span class="p">,</span> <span class="n">rays_flat</span><span class="p">,</span> <span class="n">t_vals</span><span class="p">,</span> <span class="n">rand</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Generates the RGB image and depth map from model prediction.</span> <span class="sd"> Args:</span> <span class="sd"> model: The MLP model that is trained to predict the rgb and</span> <span class="sd"> volume density of the volumetric scene.</span> <span class="sd"> rays_flat: The flattened rays that serve as the input to</span> <span class="sd"> the NeRF model.</span> <span class="sd"> t_vals: The sample points for the rays.</span> <span class="sd"> rand: Choice to randomise the sampling strategy.</span> <span class="sd"> train: Whether the model is in the training or testing phase.</span> <span class="sd"> Returns:</span> <span class="sd"> Tuple of rgb image and depth map.</span> <span class="sd"> """</span> <span class="c1"># Get the predictions from the nerf model and reshape it.</span> <span class="k">if</span> <span class="n">train</span><span class="p">:</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">rays_flat</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">rays_flat</span><span class="p">)</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">predictions</span><span class="p">,</span> <span class="n">shape</span><span 
class="o">=</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">NUM_SAMPLES</span><span class="p">,</span> <span class="mi">4</span><span class="p">))</span> <span class="c1"># Slice the predictions into rgb and sigma.</span> <span class="n">rgb</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">sigmoid</span><span class="p">(</span><span class="n">predictions</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">])</span> <span class="n">sigma_a</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">nn</span><span class="o">.</span><span class="n">relu</span><span class="p">(</span><span class="n">predictions</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">])</span> <span class="c1"># Get the distance of adjacent intervals.</span> <span class="n">delta</span> <span class="o">=</span> <span class="n">t_vals</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="mi">1</span><span class="p">:]</span> <span class="o">-</span> <span class="n">t_vals</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="p">:</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="c1"># delta shape = (num_samples)</span> <span class="k">if</span> <span class="n">rand</span><span class="p">:</span> <span class="n">delta</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">concat</span><span class="p">(</span> <span class="p">[</span><span class="n">delta</span><span class="p">,</span> 
<span class="n">tf</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">([</span><span class="mf">1e10</span><span class="p">],</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="mi">1</span><span class="p">))],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span> <span class="p">)</span> <span class="n">alpha</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">tf</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">sigma_a</span> <span class="o">*</span> <span class="n">delta</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">delta</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">concat</span><span class="p">(</span> <span class="p">[</span><span class="n">delta</span><span class="p">,</span> <span class="n">tf</span><span class="o">.</span><span class="n">broadcast_to</span><span class="p">([</span><span class="mf">1e10</span><span class="p">],</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="mi">1</span><span class="p">))],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span> <span class="p">)</span> <span class="n">alpha</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">tf</span><span class="o">.</span><span class="n">exp</span><span class="p">(</span><span class="o">-</span><span class="n">sigma_a</span> <span class="o">*</span> <span class="n">delta</span><span class="p">[:,</span> <span 
class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">,</span> <span class="p">:])</span> <span class="c1"># Get transmittance.</span> <span class="n">exp_term</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">alpha</span> <span class="n">epsilon</span> <span class="o">=</span> <span class="mf">1e-10</span> <span class="n">transmittance</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">math</span><span class="o">.</span><span class="n">cumprod</span><span class="p">(</span><span class="n">exp_term</span> <span class="o">+</span> <span class="n">epsilon</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">,</span> <span class="n">exclusive</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">weights</span> <span class="o">=</span> <span class="n">alpha</span> <span class="o">*</span> <span class="n">transmittance</span> <span class="n">rgb</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">weights</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]</span> <span class="o">*</span> <span class="n">rgb</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">2</span><span class="p">)</span> <span class="k">if</span> <span class="n">rand</span><span class="p">:</span> <span class="n">depth_map</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">weights</span> <span class="o">*</span> <span class="n">t_vals</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span 
class="mi">1</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">depth_map</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">weights</span> <span class="o">*</span> <span class="n">t_vals</span><span class="p">[:,</span> <span class="kc">None</span><span class="p">,</span> <span class="kc">None</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span> <span class="k">return</span> <span class="p">(</span><span class="n">rgb</span><span class="p">,</span> <span class="n">depth_map</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="training">Training</h2> <p>The training step is implemented as part of a custom <a href="/api/models/model#model-class"><code>keras.Model</code></a> subclass so that we can make use of the <code>model.fit</code> functionality.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">NeRF</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">nerf_model</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">nerf_model</span> <span class="o">=</span> <span class="n">nerf_model</span> <span class="k">def</span> <span class="nf">compile</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">optimizer</span><span class="p">,</span> <span class="n">loss_fn</span><span class="p">):</span> <span 
class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="n">compile</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span> <span class="o">=</span> <span class="n">optimizer</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span> <span class="o">=</span> <span class="n">loss_fn</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_tracker</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"loss"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">psnr_metric</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">Mean</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"psnr"</span><span class="p">)</span> <span class="k">def</span> <span class="nf">train_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="c1"># Get the images and the rays.</span> <span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">rays</span><span class="p">)</span> <span class="o">=</span> <span class="n">inputs</span> <span class="p">(</span><span class="n">rays_flat</span><span class="p">,</span> <span class="n">t_vals</span><span class="p">)</span> <span class="o">=</span> <span class="n">rays</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> 
<span class="c1"># Get the predictions from the model.</span> <span class="n">rgb</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">render_rgb_depth</span><span class="p">(</span> <span class="n">model</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">nerf_model</span><span class="p">,</span> <span class="n">rays_flat</span><span class="o">=</span><span class="n">rays_flat</span><span class="p">,</span> <span class="n">t_vals</span><span class="o">=</span><span class="n">t_vals</span><span class="p">,</span> <span class="n">rand</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">rgb</span><span class="p">)</span> <span class="c1"># Get the trainable variables.</span> <span class="n">trainable_variables</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">nerf_model</span><span class="o">.</span><span class="n">trainable_variables</span> <span class="c1"># Get the gradients of the loss with respect to the trainable variables.</span> <span class="n">gradients</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">loss</span><span class="p">,</span> <span class="n">trainable_variables</span><span class="p">)</span> <span class="c1"># Apply the grads and optimize the model.</span> <span class="bp">self</span><span class="o">.</span><span class="n">optimizer</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span
class="n">trainable_variables</span><span class="p">))</span> <span class="c1"># Get the PSNR of the reconstructed images and the source images.</span> <span class="n">psnr</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">psnr</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">rgb</span><span class="p">,</span> <span class="n">max_val</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span> <span class="c1"># Compute our own metrics</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">psnr_metric</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">psnr</span><span class="p">)</span> <span class="k">return</span> <span class="p">{</span><span class="s2">"loss"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_tracker</span><span class="o">.</span><span class="n">result</span><span class="p">(),</span> <span class="s2">"psnr"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">psnr_metric</span><span class="o">.</span><span class="n">result</span><span class="p">()}</span> <span class="k">def</span> <span class="nf">test_step</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="c1"># Get the images and the rays.</span> <span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">rays</span><span class="p">)</span> <span class="o">=</span> <span class="n">inputs</span> <span 
class="p">(</span><span class="n">rays_flat</span><span class="p">,</span> <span class="n">t_vals</span><span class="p">)</span> <span class="o">=</span> <span class="n">rays</span> <span class="c1"># Get the predictions from the model.</span> <span class="n">rgb</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">render_rgb_depth</span><span class="p">(</span> <span class="n">model</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">nerf_model</span><span class="p">,</span> <span class="n">rays_flat</span><span class="o">=</span><span class="n">rays_flat</span><span class="p">,</span> <span class="n">t_vals</span><span class="o">=</span><span class="n">t_vals</span><span class="p">,</span> <span class="n">rand</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="n">loss</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_fn</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">rgb</span><span class="p">)</span> <span class="c1"># Get the PSNR of the reconstructed images and the source images.</span> <span class="n">psnr</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">psnr</span><span class="p">(</span><span class="n">images</span><span class="p">,</span> <span class="n">rgb</span><span class="p">,</span> <span class="n">max_val</span><span class="o">=</span><span class="mf">1.0</span><span class="p">)</span> <span class="c1"># Compute our own metrics</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_tracker</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span 
class="n">psnr_metric</span><span class="o">.</span><span class="n">update_state</span><span class="p">(</span><span class="n">psnr</span><span class="p">)</span> <span class="k">return</span> <span class="p">{</span><span class="s2">"loss"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">loss_tracker</span><span class="o">.</span><span class="n">result</span><span class="p">(),</span> <span class="s2">"psnr"</span><span class="p">:</span> <span class="bp">self</span><span class="o">.</span><span class="n">psnr_metric</span><span class="o">.</span><span class="n">result</span><span class="p">()}</span> <span class="nd">@property</span> <span class="k">def</span> <span class="nf">metrics</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span> <span class="k">return</span> <span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">loss_tracker</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">psnr_metric</span><span class="p">]</span> <span class="n">test_imgs</span><span class="p">,</span> <span class="n">test_rays</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">train_ds</span><span class="p">))</span> <span class="n">test_rays_flat</span><span class="p">,</span> <span class="n">test_t_vals</span> <span class="o">=</span> <span class="n">test_rays</span> <span class="n">loss_list</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">class</span> <span class="nc">TrainMonitor</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">Callback</span><span class="p">):</span> <span class="k">def</span> <span class="nf">on_epoch_end</span><span class="p">(</span><span class="bp">self</span><span 
class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">logs</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">logs</span><span class="p">[</span><span class="s2">"loss"</span><span class="p">]</span> <span class="n">loss_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span> <span class="n">test_recons_images</span><span class="p">,</span> <span class="n">depth_maps</span> <span class="o">=</span> <span class="n">render_rgb_depth</span><span class="p">(</span> <span class="n">model</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">model</span><span class="o">.</span><span class="n">nerf_model</span><span class="p">,</span> <span class="n">rays_flat</span><span class="o">=</span><span class="n">test_rays_flat</span><span class="p">,</span> <span class="n">t_vals</span><span class="o">=</span><span class="n">test_t_vals</span><span class="p">,</span> <span class="n">rand</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Plot the rgb, depth and the loss plot.</span> <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">nrows</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">ncols</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">20</span><span class="p">,</span> <span 
class="mi">5</span><span class="p">))</span> <span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">array_to_img</span><span class="p">(</span><span class="n">test_recons_images</span><span class="p">[</span><span class="mi">0</span><span class="p">]))</span> <span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Predicted Image: </span><span class="si">{</span><span class="n">epoch</span><span class="si">:</span><span class="s2">03d</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">array_to_img</span><span class="p">(</span><span class="n">depth_maps</span><span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]))</span> <span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Depth Map: </span><span class="si">{</span><span class="n">epoch</span><span class="si">:</span><span class="s2">03d</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span 
class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">loss_list</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">set_xticks</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="n">EPOCHS</span> <span class="o">+</span> <span class="mi">1</span><span class="p">,</span> <span class="mf">5.0</span><span class="p">))</span> <span class="n">ax</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Loss Plot: </span><span class="si">{</span><span class="n">epoch</span><span class="si">:</span><span class="s2">03d</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="n">fig</span><span class="o">.</span><span class="n">savefig</span><span class="p">(</span><span class="sa">f</span><span class="s2">"images/</span><span class="si">{</span><span class="n">epoch</span><span class="si">:</span><span class="s2">03d</span><span class="si">}</span><span class="s2">.png"</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">close</span><span class="p">()</span> <span class="n">num_pos</span> <span class="o">=</span> <span class="n">H</span> <span class="o">*</span> <span class="n">W</span> <span class="o">*</span> <span class="n">NUM_SAMPLES</span> <span class="n">nerf_model</span> <span class="o">=</span> <span class="n">get_nerf_model</span><span class="p">(</span><span class="n">num_layers</span><span class="o">=</span><span class="mi">8</span><span 
class="p">,</span> <span class="n">num_pos</span><span class="o">=</span><span class="n">num_pos</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">NeRF</span><span class="p">(</span><span class="n">nerf_model</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(),</span> <span class="n">loss_fn</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">MeanSquaredError</span><span class="p">()</span> <span class="p">)</span> <span class="c1"># Create a directory to save the images during training.</span> <span class="k">if</span> <span class="ow">not</span> <span class="n">os</span><span class="o">.</span><span class="n">path</span><span class="o">.</span><span class="n">exists</span><span class="p">(</span><span class="s2">"images"</span><span class="p">):</span> <span class="n">os</span><span class="o">.</span><span class="n">makedirs</span><span class="p">(</span><span class="s2">"images"</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">train_ds</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">val_ds</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">BATCH_SIZE</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">EPOCHS</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">TrainMonitor</span><span 
class="p">()],</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">create_gif</span><span class="p">(</span><span class="n">path_to_images</span><span class="p">,</span> <span class="n">name_gif</span><span class="p">):</span> <span class="n">filenames</span> <span class="o">=</span> <span class="n">glob</span><span class="o">.</span><span class="n">glob</span><span class="p">(</span><span class="n">path_to_images</span><span class="p">)</span> <span class="n">filenames</span> <span class="o">=</span> <span class="nb">sorted</span><span class="p">(</span><span class="n">filenames</span><span class="p">)</span> <span class="n">images</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">filename</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="n">filenames</span><span class="p">):</span> <span class="n">images</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">imageio</span><span class="o">.</span><span class="n">imread</span><span class="p">(</span><span class="n">filename</span><span class="p">))</span> <span class="n">kargs</span> <span class="o">=</span> <span class="p">{</span><span class="s2">"duration"</span><span class="p">:</span> <span class="mf">0.25</span><span class="p">}</span> <span class="n">imageio</span><span class="o">.</span><span class="n">mimsave</span><span class="p">(</span><span class="n">name_gif</span><span class="p">,</span> <span class="n">images</span><span class="p">,</span> <span class="s2">"GIF"</span><span class="p">,</span> <span class="o">**</span><span class="n">kargs</span><span class="p">)</span> <span class="n">create_gif</span><span class="p">(</span><span class="s2">"images/*.png"</span><span class="p">,</span> <span class="s2">"training.gif"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div 
class="codehilite"><pre><span></span><code>Epoch 1/20 1/16 ━━━━━━━━━━━━━━━━━━━━ 3:54 16s/step - loss: 0.0948 - psnr: 10.6234 WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1699908753.457905 65271 device_compiler.h:187] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. 1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 924ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/nerf/nerf_11_3.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 16/16 ━━━━━━━━━━━━━━━━━━━━ 29s 889ms/step - loss: 0.1091 - psnr: 9.8283 - val_loss: 0.0753 - val_psnr: 11.5686 Epoch 2/20 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 477ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/nerf/nerf_11_5.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 16/16 ━━━━━━━━━━━━━━━━━━━━ 16s 926ms/step - loss: 0.0633 - psnr: 12.4819 - val_loss: 0.0657 - val_psnr: 12.1781 Epoch 3/20 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 474ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/nerf/nerf_11_7.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 16/16 ━━━━━━━━━━━━━━━━━━━━ 16s 921ms/step - loss: 0.0589 - psnr: 12.6268 - val_loss: 0.0637 - val_psnr: 12.3413 Epoch 4/20 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 470ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/nerf/nerf_11_9.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 16/16 ━━━━━━━━━━━━━━━━━━━━ 15s 915ms/step - loss: 0.0573 - psnr: 12.8150 - val_loss: 0.0617 - val_psnr: 12.4789 Epoch 5/20 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 477ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/nerf/nerf_11_11.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 16/16 ━━━━━━━━━━━━━━━━━━━━ 15s
918ms/step - loss: 0.0552 - psnr: 12.9703 - val_loss: 0.0594 - val_psnr: 12.6457 Epoch 6/20 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 476ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/nerf/nerf_11_13.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 16/16 ━━━━━━━━━━━━━━━━━━━━ 15s 894ms/step - loss: 0.0538 - psnr: 13.0895 - val_loss: 0.0533 - val_psnr: 13.0049 Epoch 7/20 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 473ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/nerf/nerf_11_15.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 16/16 ━━━━━━━━━━━━━━━━━━━━ 16s 940ms/step - loss: 0.0436 - psnr: 13.9857 - val_loss: 0.0381 - val_psnr: 14.4764 Epoch 8/20 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 475ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/nerf/nerf_11_17.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 16/16 ━━━━━━━━━━━━━━━━━━━━ 15s 919ms/step - loss: 0.0325 - psnr: 15.1856 - val_loss: 0.0294 - val_psnr: 15.5187 Epoch 9/20 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 478ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/nerf/nerf_11_19.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 16/16 ━━━━━━━━━━━━━━━━━━━━ 16s 927ms/step - loss: 0.0276 - psnr: 15.8105 - val_loss: 0.0259 - val_psnr: 16.0297 Epoch 10/20 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 474ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/nerf/nerf_11_21.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 16/16 ━━━━━━━━━━━━━━━━━━━━ 16s 952ms/step - loss: 0.0251 - psnr: 16.1994 - val_loss: 0.0252 - val_psnr: 16.0842 Epoch 11/20 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 474ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/nerf/nerf_11_23.png" /></p> <div class="k-default-codeblock"> <div 
class="codehilite"><pre><span></span><code> 16/16 ━━━━━━━━━━━━━━━━━━━━ 15s 909ms/step - loss: 0.0239 - psnr: 16.3749 - val_loss: 0.0228 - val_psnr: 16.5269 Epoch 12/20 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 474ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/nerf/nerf_11_25.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 16/16 ━━━━━━━━━━━━━━━━━━━━ 19s 1s/step - loss: 0.0215 - psnr: 16.8117 - val_loss: 0.0186 - val_psnr: 17.3930 Epoch 13/20 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 474ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/nerf/nerf_11_27.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 16/16 ━━━━━━━━━━━━━━━━━━━━ 16s 923ms/step - loss: 0.0188 - psnr: 17.3916 - val_loss: 0.0174 - val_psnr: 17.6570 Epoch 14/20 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 476ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/nerf/nerf_11_29.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 16/16 ━━━━━━━━━━━━━━━━━━━━ 16s 973ms/step - loss: 0.0175 - psnr: 17.6871 - val_loss: 0.0172 - val_psnr: 17.6644 Epoch 15/20 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 468ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/nerf/nerf_11_31.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 16/16 ━━━━━━━━━━━━━━━━━━━━ 15s 919ms/step - loss: 0.0172 - psnr: 17.7639 - val_loss: 0.0161 - val_psnr: 18.0313 Epoch 16/20 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 477ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/nerf/nerf_11_33.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 16/16 ━━━━━━━━━━━━━━━━━━━━ 16s 915ms/step - loss: 0.0150 - psnr: 18.3860 - val_loss: 0.0151 - val_psnr: 18.2832 Epoch 17/20 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 473ms/step </code></pre></div> </div> <p><img alt="png" 
src="/img/examples/vision/nerf/nerf_11_35.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 16/16 ━━━━━━━━━━━━━━━━━━━━ 16s 926ms/step - loss: 0.0154 - psnr: 18.2210 - val_loss: 0.0146 - val_psnr: 18.4284 Epoch 18/20 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 468ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/nerf/nerf_11_37.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 16/16 ━━━━━━━━━━━━━━━━━━━━ 16s 959ms/step - loss: 0.0145 - psnr: 18.4869 - val_loss: 0.0134 - val_psnr: 18.8039 Epoch 19/20 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 473ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/nerf/nerf_11_39.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 16/16 ━━━━━━━━━━━━━━━━━━━━ 16s 933ms/step - loss: 0.0136 - psnr: 18.8040 - val_loss: 0.0138 - val_psnr: 18.6680 Epoch 20/20 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 472ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/nerf/nerf_11_41.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 16/16 ━━━━━━━━━━━━━━━━━━━━ 15s 916ms/step - loss: 0.0131 - psnr: 18.9661 - val_loss: 0.0132 - val_psnr: 18.8687 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 59.40it/s] </code></pre></div> </div> <hr /> <h2 id="visualize-the-training-step">Visualize the training step</h2> <p>Here we see the training step. With the decreasing loss, the rendered image and the depth maps are getting better. 
If you run this example locally, a <code>training.gif</code> file will be generated.</p> <p><img alt="training-20" src="https://i.imgur.com/ql5OcYA.gif" /></p> <hr /> <h2 id="inference">Inference</h2> <p>In this section, we ask the model to synthesize novel views of the scene. The model was given <code>106</code> views of the scene during training. No collection of training images can cover every angle of the scene, yet a trained model can represent the entire 3-D scene from a sparse set of training images.</p> <p>Here we provide different poses to the model and ask it to render the 2-D image corresponding to each camera view. If we run inference over the full 360 degrees of viewpoints, the model should give us a view of the entire scene from all around.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Get the trained NeRF model and infer.</span> <span class="n">nerf_model</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">nerf_model</span> <span class="n">test_recons_images</span><span class="p">,</span> <span class="n">depth_maps</span> <span class="o">=</span> <span class="n">render_rgb_depth</span><span class="p">(</span> <span class="n">model</span><span class="o">=</span><span class="n">nerf_model</span><span class="p">,</span> <span class="n">rays_flat</span><span class="o">=</span><span class="n">test_rays_flat</span><span class="p">,</span> <span class="n">t_vals</span><span class="o">=</span><span class="n">test_t_vals</span><span class="p">,</span> <span class="n">rand</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Create subplots.</span> <span class="n">fig</span><span class="p">,</span> <span class="n">axes</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span
class="n">subplots</span><span class="p">(</span><span class="n">nrows</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">ncols</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">20</span><span class="p">))</span> <span class="k">for</span> <span class="n">ax</span><span class="p">,</span> <span class="n">ori_img</span><span class="p">,</span> <span class="n">recons_img</span><span class="p">,</span> <span class="n">depth_map</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span> <span class="n">axes</span><span class="p">,</span> <span class="n">test_imgs</span><span class="p">,</span> <span class="n">test_recons_images</span><span class="p">,</span> <span class="n">depth_maps</span> <span class="p">):</span> <span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">array_to_img</span><span class="p">(</span><span class="n">ori_img</span><span class="p">))</span> <span class="n">ax</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">"Original"</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">array_to_img</span><span class="p">(</span><span class="n">recons_img</span><span class="p">))</span> 
<span class="n">ax</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">"Reconstructed"</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">array_to_img</span><span class="p">(</span><span class="n">depth_map</span><span class="p">[</span><span class="o">...</span><span class="p">,</span> <span class="kc">None</span><span class="p">]),</span> <span class="n">cmap</span><span class="o">=</span><span class="s2">"inferno"</span><span class="p">)</span> <span class="n">ax</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">"Depth Map"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 475ms/step </code></pre></div> </div> <p><img alt="png" src="/img/examples/vision/nerf/nerf_14_1.png" /></p> <hr /> <h2 id="render-3d-scene">Render 3D Scene</h2> <p>Here we will synthesize novel 3D views and stitch all of them together to render a video encompassing the 360-degree view.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">get_translation_t</span><span class="p">(</span><span class="n">t</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Get the translation matrix for movement in t."""</span> <span class="n">matrix</span> <span class="o">=</span> <span class="p">[</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span 
class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="n">t</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="p">]</span> <span class="k">return</span> <span class="n">tf</span><span class="o">.</span><span class="n">convert_to_tensor</span><span class="p">(</span><span class="n">matrix</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="k">def</span> <span class="nf">get_rotation_phi</span><span class="p">(</span><span class="n">phi</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Get the rotation matrix for movement in phi."""</span> <span class="n">matrix</span> <span class="o">=</span> <span class="p">[</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">tf</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">phi</span><span class="p">),</span> <span class="o">-</span><span class="n">tf</span><span 
class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">phi</span><span class="p">),</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="n">tf</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">phi</span><span class="p">),</span> <span class="n">tf</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">phi</span><span class="p">),</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="p">]</span> <span class="k">return</span> <span class="n">tf</span><span class="o">.</span><span class="n">convert_to_tensor</span><span class="p">(</span><span class="n">matrix</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="k">def</span> <span class="nf">get_rotation_theta</span><span class="p">(</span><span class="n">theta</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Get the rotation matrix for movement in theta."""</span> <span class="n">matrix</span> <span class="o">=</span> <span class="p">[</span> <span class="p">[</span><span class="n">tf</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">theta</span><span class="p">),</span> <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="n">tf</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">theta</span><span class="p">),</span> <span class="mi">0</span><span class="p">],</span> <span 
class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="n">tf</span><span class="o">.</span><span class="n">sin</span><span class="p">(</span><span class="n">theta</span><span class="p">),</span> <span class="mi">0</span><span class="p">,</span> <span class="n">tf</span><span class="o">.</span><span class="n">cos</span><span class="p">(</span><span class="n">theta</span><span class="p">),</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">],</span> <span class="p">]</span> <span class="k">return</span> <span class="n">tf</span><span class="o">.</span><span class="n">convert_to_tensor</span><span class="p">(</span><span class="n">matrix</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="k">def</span> <span class="nf">pose_spherical</span><span class="p">(</span><span class="n">theta</span><span class="p">,</span> <span class="n">phi</span><span class="p">,</span> <span class="n">t</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""</span> <span class="sd"> Get the camera to world matrix for the corresponding theta, phi</span> <span class="sd"> and t.</span> <span class="sd"> """</span> <span class="n">c2w</span> <span class="o">=</span> <span class="n">get_translation_t</span><span class="p">(</span><span class="n">t</span><span class="p">)</span> <span class="n">c2w</span> <span class="o">=</span> <span class="n">get_rotation_phi</span><span class="p">(</span><span 
class="n">phi</span> <span class="o">/</span> <span class="mf">180.0</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">pi</span><span class="p">)</span> <span class="o">@</span> <span class="n">c2w</span> <span class="n">c2w</span> <span class="o">=</span> <span class="n">get_rotation_theta</span><span class="p">(</span><span class="n">theta</span> <span class="o">/</span> <span class="mf">180.0</span> <span class="o">*</span> <span class="n">np</span><span class="o">.</span><span class="n">pi</span><span class="p">)</span> <span class="o">@</span> <span class="n">c2w</span> <span class="n">c2w</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">([[</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">],</span> <span class="p">[</span><span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">]])</span> <span class="o">@</span> <span class="n">c2w</span> <span class="k">return</span> <span class="n">c2w</span> <span class="n">rgb_frames</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">batch_flat</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">batch_t</span> 
<span class="o">=</span> <span class="p">[]</span> <span class="c1"># Iterate over different theta values and generate scenes.</span> <span class="k">for</span> <span class="n">index</span><span class="p">,</span> <span class="n">theta</span> <span class="ow">in</span> <span class="n">tqdm</span><span class="p">(</span><span class="nb">enumerate</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="mf">0.0</span><span class="p">,</span> <span class="mf">360.0</span><span class="p">,</span> <span class="mi">120</span><span class="p">,</span> <span class="n">endpoint</span><span class="o">=</span><span class="kc">False</span><span class="p">))):</span> <span class="c1"># Get the camera-to-world matrix.</span> <span class="n">c2w</span> <span class="o">=</span> <span class="n">pose_spherical</span><span class="p">(</span><span class="n">theta</span><span class="p">,</span> <span class="o">-</span><span class="mf">30.0</span><span class="p">,</span> <span class="mf">4.0</span><span class="p">)</span> <span class="c1"># Get the ray origins and directions for this pose, then sample points along them.</span> <span class="n">ray_oris</span><span class="p">,</span> <span class="n">ray_dirs</span> <span class="o">=</span> <span class="n">get_rays</span><span class="p">(</span><span class="n">H</span><span class="p">,</span> <span class="n">W</span><span class="p">,</span> <span class="n">focal</span><span class="p">,</span> <span class="n">c2w</span><span class="p">)</span> <span class="n">rays_flat</span><span class="p">,</span> <span class="n">t_vals</span> <span class="o">=</span> <span class="n">render_flat_rays</span><span class="p">(</span> <span class="n">ray_oris</span><span class="p">,</span> <span class="n">ray_dirs</span><span class="p">,</span> <span class="n">near</span><span class="o">=</span><span class="mf">2.0</span><span class="p">,</span> <span class="n">far</span><span class="o">=</span><span class="mf">6.0</span><span class="p">,</span> <span 
class="n">num_samples</span><span class="o">=</span><span class="n">NUM_SAMPLES</span><span class="p">,</span> <span class="n">rand</span><span class="o">=</span><span class="kc">False</span> <span class="p">)</span> <span class="k">if</span> <span class="n">index</span> <span class="o">%</span> <span class="n">BATCH_SIZE</span> <span class="o">==</span> <span class="mi">0</span> <span class="ow">and</span> <span class="n">index</span> <span class="o">></span> <span class="mi">0</span><span class="p">:</span> <span class="n">batched_flat</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">batch_flat</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">batch_flat</span> <span class="o">=</span> <span class="p">[</span><span class="n">rays_flat</span><span class="p">]</span> <span class="n">batched_t</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">batch_t</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">batch_t</span> <span class="o">=</span> <span class="p">[</span><span class="n">t_vals</span><span class="p">]</span> <span class="n">rgb</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">render_rgb_depth</span><span class="p">(</span> <span class="n">nerf_model</span><span class="p">,</span> <span class="n">batched_flat</span><span class="p">,</span> <span class="n">batched_t</span><span class="p">,</span> <span class="n">rand</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span> <span class="p">)</span> <span 
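class="n">temp_rgb</span> <span class="o">=</span> <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="mi">255</span> <span class="o">*</span> <span class="n">img</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">255.0</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span> <span class="k">for</span> <span class="n">img</span> <span class="ow">in</span> <span class="n">rgb</span><span class="p">]</span> <span class="n">rgb_frames</span> <span class="o">=</span> <span class="n">rgb_frames</span> <span class="o">+</span> <span class="n">temp_rgb</span> <span class="k">else</span><span class="p">:</span> <span class="n">batch_flat</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">rays_flat</span><span class="p">)</span> <span class="n">batch_t</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">t_vals</span><span class="p">)</span> <span class="c1"># After the loop, render the frames still buffered in the final batch</span> <span class="c1"># (a full batch when the frame count is a multiple of BATCH_SIZE).</span> <span class="k">if</span> <span class="n">batch_flat</span><span class="p">:</span> <span class="n">rgb</span><span class="p">,</span> <span class="n">_</span> <span class="o">=</span> <span class="n">render_rgb_depth</span><span class="p">(</span> <span class="n">nerf_model</span><span class="p">,</span> <span class="n">tf</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">batch_flat</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span> <span class="n">tf</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">batch_t</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">),</span> <span class="n">rand</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">train</span><span class="o">=</span><span class="kc">False</span> <span class="p">)</span> <span class="n">rgb_frames</span> <span class="o">+=</span> <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">clip</span><span class="p">(</span><span class="mi">255</span> <span class="o">*</span> <span class="n">img</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mf">255.0</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">uint8</span><span class="p">)</span> <span class="k">for</span> <span class="n">img</span> <span class="ow">in</span> <span class="n">rgb</span><span class="p">]</span> <span class="c1"># Save the rendered frames as a video.</span> <span class="n">rgb_video</span> <span class="o">=</span> <span class="s2">"rgb_video.mp4"</span> <span class="n">imageio</span><span class="o">.</span><span class="n">mimwrite</span><span class="p">(</span><span class="n">rgb_video</span><span class="p">,</span> <span class="n">rgb_frames</span><span class="p">,</span> <span class="n">fps</span><span class="o">=</span><span class="mi">30</span><span class="p">,</span> <span class="n">quality</span><span class="o">=</span><span class="mi">7</span><span class="p">,</span> <span class="n">macro_block_size</span><span class="o">=</span><span class="kc">None</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>1it [00:01, 1.02s/it] 1/1 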
━━━━━━━━━━━━━━━━━━━━ 0s 475ms/step 6it [00:03, 1.95it/s] 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 478ms/step 11it [00:05, 2.11it/s] 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 474ms/step 16it [00:07, 2.17it/s] 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 477ms/step 25it [00:10, 3.05it/s] 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 477ms/step 27it [00:12, 2.14it/s] 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 479ms/step 31it [00:14, 2.02it/s] 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 472ms/step 36it [00:16, 2.11it/s] 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 474ms/step 41it [00:18, 2.16it/s] 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 472ms/step 46it [00:21, 2.19it/s] 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 475ms/step 51it [00:23, 2.22it/s] 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 473ms/step 56it [00:25, 2.24it/s] 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 464ms/step 61it [00:27, 2.26it/s] 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 474ms/step 66it [00:29, 2.26it/s] 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 476ms/step 71it [00:32, 2.26it/s] 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 473ms/step 76it [00:34, 2.26it/s] 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 475ms/step 81it [00:36, 2.26it/s] 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 474ms/step 86it [00:38, 2.26it/s] 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 476ms/step 91it [00:40, 2.26it/s] 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 465ms/step 96it [00:43, 2.27it/s] 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 473ms/step 101it [00:45, 2.28it/s] 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 473ms/step 106it [00:47, 2.28it/s] 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 473ms/step 111it [00:49, 2.27it/s] 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 474ms/step 120it [00:52, 2.31it/s] [swscaler @ 0x67626c0] Warning: data is not aligned! This can lead to a speed loss </code></pre></div> </div> <h3 id="visualize-the-video">Visualize the video</h3> <p>Here we can see the rendered 360-degree view of the scene. The model has successfully learned the entire volumetric space from a sparse set of images in <strong>only 20 epochs</strong>. 
The rendered video is saved locally as <code>rgb_video.mp4</code>.</p> <p><img alt="rendered-video" src="https://i.imgur.com/j2sIkzW.gif" /></p> <hr /> <h2 id="conclusion">Conclusion</h2> <p>We have produced a minimal implementation of NeRF to build intuition for its core ideas and methodology. The method has been used in various other works in computer graphics.</p> <p>We encourage readers to use this code as an example, experiment with the hyperparameters, and visualize the outputs. The table below shows the outputs of the model trained for more epochs.</p> <table> <thead> <tr> <th style="text-align: left;">Epochs</th> <th style="text-align: center;">GIF of the training step</th> </tr> </thead> <tbody> <tr> <td style="text-align: left;"><strong>100</strong></td> <td style="text-align: center;"><img alt="100-epoch-training" src="https://i.imgur.com/2k9p8ez.gif" /></td> </tr> <tr> <td style="text-align: left;"><strong>200</strong></td> <td style="text-align: center;"><img alt="200-epoch-training" src="https://i.imgur.com/l3rG4HQ.gif" /></td> </tr> </tbody> </table> <hr /> <h2 id="way-forward">Way forward</h2> <p>If you are interested in going deeper into NeRF, we have written a three-part blog series at <a href="https://pyimagesearch.com/">PyImageSearch</a>.</p> <ul> <li><a href="https://www.pyimagesearch.com/2021/11/10/computer-graphics-and-deep-learning-with-nerf-using-tensorflow-and-keras-part-1/">Prerequisites of NeRF</a></li> <li><a href="https://www.pyimagesearch.com/2021/11/17/computer-graphics-and-deep-learning-with-nerf-using-tensorflow-and-keras-part-2/">Concepts of NeRF</a></li> <li><a href="https://www.pyimagesearch.com/2021/11/24/computer-graphics-and-deep-learning-with-nerf-using-tensorflow-and-keras-part-3/">Implementing NeRF</a></li> </ul> <hr /> <h2 id="reference">Reference</h2> <ul> <li><a href="https://github.com/bmild/nerf">NeRF repository</a>: The official repository for NeRF.</li> <li><a 
href="https://arxiv.org/abs/2003.08934">NeRF paper</a>: The original NeRF paper.</li> <li><a href="https://github.com/3b1b/manim">Manim repository</a>: We used Manim to build all the animations.</li> <li><a href="https://www.mathworks.com/help/vision/ug/camera-calibration.html">MathWorks</a>: The MathWorks article on camera calibration.</li> <li><a href="https://www.youtube.com/watch?v=dPWLybp4LL0">Mathew's video</a>: A great video on NeRF.</li> </ul> <p>You can try the model on <a href="https://huggingface.co/spaces/keras-io/NeRF">Hugging Face Spaces</a>.</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#3d-volumetric-rendering-with-nerf'>3D volumetric rendering with NeRF</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#download-and-load-the-data'>Download and load the data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#data-pipeline'>Data pipeline</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#nerf-model'>NeRF model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#training'>Training</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#visualize-the-training-step'>Visualize the training step</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#inference'>Inference</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#render-3d-scene'>Render 3D Scene</a> </div> <div class='k-outline-depth-3'> <a href='#visualize-the-video'>Visualize the video</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#conclusion'>Conclusion</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#way-forward'>Way forward</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#reference'>Reference</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a 
href="https://policies.google.com/privacy">Privacy</a> </footer> </html>