<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/keras_recipes/tensorflow_numpy_models/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Writing Keras Models With TensorFlow NumPy"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Writing Keras Models With TensorFlow NumPy"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Writing Keras Models With TensorFlow NumPy</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 
'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink active" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-sublink2" href="/examples/keras_recipes/parameter_efficient_finetuning_of_gemma_with_lora_and_qlora/">Parameter-efficient fine-tuning of Gemma with LoRA and QLoRA</a> <a 
class="nav-sublink2" href="/examples/keras_recipes/float8_training_and_inference_with_transformer/">Float8 training and inference with a simple Transformer model</a> <a class="nav-sublink2" href="/examples/keras_recipes/tf_serving/">Serving TensorFlow models with TFServing</a> <a class="nav-sublink2" href="/examples/keras_recipes/debugging_tips/">Keras debugging tips</a> <a class="nav-sublink2" href="/examples/keras_recipes/subclassing_conv_layers/">Customizing the convolution operation of a Conv2D layer</a> <a class="nav-sublink2" href="/examples/keras_recipes/trainer_pattern/">Trainer pattern</a> <a class="nav-sublink2" href="/examples/keras_recipes/endpoint_layer_pattern/">Endpoint layer pattern</a> <a class="nav-sublink2" href="/examples/keras_recipes/reproducibility_recipes/">Reproducibility in Keras Models</a> <a class="nav-sublink2 active" href="/examples/keras_recipes/tensorflow_numpy_models/">Writing Keras Models With TensorFlow NumPy</a> <a class="nav-sublink2" href="/examples/keras_recipes/antirectifier/">Simple custom layer example: Antirectifier</a> <a class="nav-sublink2" href="/examples/keras_recipes/sample_size_estimate/">Estimating required sample size for model training</a> <a class="nav-sublink2" href="/examples/keras_recipes/memory_efficient_embeddings/">Memory-efficient embeddings for recommendation systems</a> <a class="nav-sublink2" href="/examples/keras_recipes/creating_tfrecords/">Creating TFRecords</a> <a class="nav-sublink2" href="/examples/keras_recipes/packaging_keras_models_for_wide_distribution/">Packaging Keras models for wide distribution using Functional Subclassing</a> <a class="nav-sublink2" href="/examples/keras_recipes/approximating_non_function_mappings/">Approximating non-Function Mappings with Mixture Density Networks</a> <a class="nav-sublink2" href="/examples/keras_recipes/bayesian_neural_networks/">Probabilistic Bayesian Neural Networks</a> <a class="nav-sublink2" 
href="/examples/keras_recipes/better_knowledge_distillation/">Knowledge distillation recipes</a> <a class="nav-sublink2" href="/examples/keras_recipes/sklearn_metric_callbacks/">Evaluating and exporting scikit-learn metrics in a Keras callback</a> <a class="nav-sublink2" href="/examples/keras_recipes/tfrecord/">How to train a Keras model on TFRecord files</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex 
align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/keras_recipes/'>Quick Keras Recipes</a> / Writing Keras Models With TensorFlow NumPy </div> <div class='k-content'> <h1 id="writing-keras-models-with-tensorflow-numpy">Writing Keras Models With TensorFlow NumPy</h1> <p><strong>Author:</strong> <a href="https://lukewood.xyz">lukewood</a><br> <strong>Date created:</strong> 2021/08/28<br> <strong>Last modified:</strong> 2021/08/28<br>
<strong>Description:</strong> Overview of how to use the TensorFlow NumPy API to write Keras models.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_recipes/ipynb/tensorflow_numpy_models.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/keras_recipes/tensorflow_numpy_models.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p><a href="https://numpy.org/">NumPy</a> is a hugely successful Python linear algebra library.</p> <p>TensorFlow recently launched <a href="https://www.tensorflow.org/guide/tf_numpy">tf_numpy</a>, a TensorFlow implementation of a large subset of the NumPy API. Thanks to <code>tf_numpy</code>, you can write Keras layers or models in the NumPy style!</p> <p>The TensorFlow NumPy API has full integration with the TensorFlow ecosystem. 
Features such as automatic differentiation, TensorBoard, Keras model callbacks, TPU distribution and model exporting are all supported.</p> <p>Let's run through a few examples.</p> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"tensorflow"</span> <span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span> <span class="kn">import</span> <span class="nn">tensorflow.experimental.numpy</span> <span class="k">as</span> <span class="nn">tnp</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> </code></pre></div> <p>To test our models we will use the Boston housing prices regression dataset.</p> <div class="codehilite"><pre><span></span><code><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">),</span> <span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">)</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">datasets</span><span class="o">.</span><span class="n">boston_housing</span><span class="o">.</span><span class="n">load_data</span><span class="p">(</span> <span class="n">path</span><span class="o">=</span><span class="s2">"boston_housing.npz"</span><span class="p">,</span> <span class="n">test_split</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">seed</span><span class="o">=</span><span class="mi">113</span> <span class="p">)</span> <span 
class="n">input_dim</span> <span class="o">=</span> <span class="n">x_train</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="k">def</span> <span class="nf">evaluate_model</span><span class="p">(</span><span class="n">model</span><span class="p">:</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="n">loss</span><span class="p">,</span> <span class="n">percent_error</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Mean absolute percent error before training: "</span><span class="p">,</span> <span class="n">percent_error</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">200</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">loss</span><span class="p">,</span> <span class="n">percent_error</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">x_test</span><span class="p">,</span> <span class="n">y_test</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="nb">print</span><span 
class="p">(</span><span class="s2">"Mean absolute percent error after training:"</span><span class="p">,</span> <span class="n">percent_error</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/california_housing.npz 743530/743530 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step </code></pre></div> </div> <hr /> <h2 id="subclassing-kerasmodel-with-tnp">Subclassing keras.Model with TNP</h2> <p>The most flexible way to use the Keras API is to subclass the <a href="https://keras.io/api/models/model/"><code>keras.Model</code></a> class. Subclassing <code>Model</code> lets you fully customize what happens in the training loop, which makes it a popular option for researchers.</p> <p>In this example, we will implement a <code>Model</code> subclass that performs regression on the Boston housing dataset using the TNP API.
Note that differentiation and gradient descent are handled automatically when using the TNP API alongside Keras.</p> <p>First, let's define a simple <code>TNPForwardFeedRegressionNetwork</code> class.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">TNPForwardFeedRegressionNetwork</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">blocks</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="o">**</span><span class="n">kwargs</span><span class="p">)</span> <span class="k">if</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">blocks</span><span class="p">,</span> <span class="nb">list</span><span class="p">):</span> <span class="k">raise</span> <span class="ne">ValueError</span><span class="p">(</span><span class="sa">f</span><span class="s2">"blocks must be a list, got blocks=</span><span class="si">{</span><span class="n">blocks</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">blocks</span> <span class="o">=</span> <span class="n">blocks</span> <span class="bp">self</span><span class="o">.</span><span class="n">block_weights</span> <span class="o">=</span> <span class="kc">None</span> <span class="bp">self</span><span class="o">.</span><span class="n">biases</span> <span class="o">=</span> <span class="kc">None</span> <span class="k">def</span> <span class="nf">build</span><span
class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="n">current_shape</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="bp">self</span><span class="o">.</span><span class="n">block_weights</span> <span class="o">=</span> <span class="p">[]</span> <span class="bp">self</span><span class="o">.</span><span class="n">biases</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">block</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">blocks</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">block_weights</span><span class="o">.</span><span class="n">append</span><span class="p">(</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">current_shape</span><span class="p">,</span> <span class="n">block</span><span class="p">),</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">"block-</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">"</span><span class="p">,</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">"glorot_normal"</span><span class="p">,</span> <span class="p">)</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">biases</span><span class="o">.</span><span class="n">append</span><span 
class="p">(</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">block</span><span class="p">,),</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">"bias-</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">"</span><span class="p">,</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">"zeros"</span><span class="p">,</span> <span class="p">)</span> <span class="p">)</span> <span class="n">current_shape</span> <span class="o">=</span> <span class="n">block</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_layer</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">current_shape</span><span class="p">,</span> <span class="mi">1</span><span class="p">),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"linear_projector"</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">"glorot_normal"</span><span class="p">,</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">activations</span> <span class="o">=</span> <span class="n">inputs</span> <span class="k">for</span> <span class="n">w</span><span class="p">,</span> <span 
class="n">b</span> <span class="ow">in</span> <span class="nb">zip</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">block_weights</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">biases</span><span class="p">):</span> <span class="n">activations</span> <span class="o">=</span> <span class="n">tnp</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">activations</span><span class="p">,</span> <span class="n">w</span><span class="p">)</span> <span class="o">+</span> <span class="n">b</span> <span class="c1"># ReLU activation function</span> <span class="n">activations</span> <span class="o">=</span> <span class="n">tnp</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">activations</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">)</span> <span class="k">return</span> <span class="n">tnp</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">activations</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">linear_layer</span><span class="p">)</span> </code></pre></div> <p>Just like with any other Keras model, we can use any supported optimizer, loss, metrics, or callbacks.</p> <p>Let's see how the model performs!</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="n">TNPForwardFeedRegressionNetwork</span><span class="p">(</span><span class="n">blocks</span><span class="o">=</span><span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span
class="s2">"adam"</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"mean_squared_error"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">MeanAbsolutePercentageError</span><span class="p">()],</span> <span class="p">)</span> <span class="n">evaluate_model</span><span class="p">(</span><span class="n">model</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>WARNING: All log messages before absl::InitializeLog() is called are written to STDERR I0000 00:00:1699909864.025985 48611 device_compiler.h:187] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. Mean absolute percent error before training: 99.96772766113281 Mean absolute percent error after training: 40.94866180419922 </code></pre></div> </div> <p>Great! 
Our model seems to be effectively learning to solve the problem at hand.</p> <p>We can also write our own custom loss function using TNP.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">tnp_mse</span><span class="p">(</span><span class="n">y_true</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">):</span> <span class="k">return</span> <span class="n">tnp</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">tnp</span><span class="o">.</span><span class="n">square</span><span class="p">(</span><span class="n">y_true</span> <span class="o">-</span> <span class="n">y_pred</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">keras</span><span class="o">.</span><span class="n">backend</span><span class="o">.</span><span class="n">clear_session</span><span class="p">()</span> <span class="n">model</span> <span class="o">=</span> <span class="n">TNPForwardFeedRegressionNetwork</span><span class="p">(</span><span class="n">blocks</span><span class="o">=</span><span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="s2">"adam"</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="n">tnp_mse</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">MeanAbsolutePercentageError</span><span class="p">()],</span> <span class="p">)</span> <span class="n">evaluate_model</span><span class="p">(</span><span 
class="n">model</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Mean absolute percent error before training: 99.99896240234375 Mean absolute percent error after training: 52.533199310302734 </code></pre></div> </div> <hr /> <h2 id="implementing-a-keras-layer-based-model-with-tnp">Implementing a Keras Layer-Based Model with TNP</h2> <p>TNP can also be used in layer-oriented Keras code. Let's implement the same model using a layered approach!</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">tnp_relu</span><span class="p">(</span><span class="n">x</span><span class="p">):</span> <span class="k">return</span> <span class="n">tnp</span><span class="o">.</span><span class="n">maximum</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="mi">0</span><span class="p">)</span> <span class="k">class</span> <span class="nc">TNPDense</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">units</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span> <span class="o">=</span> <span class="n">units</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">activation</span> <span class="k">def</span> <span class="nf">build</span><span
class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">name</span><span class="o">=</span><span class="s2">"weights"</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">input_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">],</span> <span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">),</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">"random_normal"</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">name</span><span class="o">=</span><span class="s2">"bias"</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">units</span><span class="p">,),</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">"zeros"</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="n">outputs</span> <span class="o">=</span> <span 
class="n">tnp</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">w</span><span class="p">)</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">bias</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">:</span> <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="n">outputs</span><span class="p">)</span> <span class="k">return</span> <span class="n">outputs</span> <span class="k">def</span> <span class="nf">create_layered_tnp_model</span><span class="p">():</span> <span class="k">return</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">TNPDense</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">tnp_relu</span><span class="p">),</span> <span class="n">TNPDense</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">tnp_relu</span><span class="p">),</span> <span class="n">TNPDense</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="p">]</span> <span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">create_layered_tnp_model</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="s2">"adam"</span><span class="p">,</span> <span class="n">loss</span><span 
class="o">=</span><span class="s2">"mean_squared_error"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">MeanAbsolutePercentageError</span><span class="p">()],</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">build</span><span class="p">((</span><span class="kc">None</span><span class="p">,</span> <span class="n">input_dim</span><span class="p">))</span> <span class="n">model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> <span class="n">evaluate_model</span><span class="p">(</span><span class="n">model</span><span class="p">)</span> </code></pre></div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "sequential"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ tnp_dense (<span style="color: #0087ff; text-decoration-color: #0087ff">TNPDense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">27</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ tnp_dense_1 (<span style="color: #0087ff; text-decoration-color: #0087ff">TNPDense</span>) │ (<span style="color: #00d7ff; 
text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">12</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ tnp_dense_2 (<span style="color: #0087ff; text-decoration-color: #0087ff">TNPDense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">4</span> │ └─────────────────────────────────┴───────────────────────────┴────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">43</span> (172.00 B) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">43</span> (172.00 B) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Mean absolute percent error before training: 100.00006866455078 Mean absolute percent error after training: 43.57806396484375 </code></pre></div> </div> <p>You can also seamlessly switch between TNP layers and native Keras layers!</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">create_mixed_model</span><span class="p">():</span> <span class="k">return</span> 
<span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">TNPDense</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="n">tnp_relu</span><span class="p">),</span> <span class="c1"># The model will have no issue using a normal Dense layer</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">),</span> <span class="c1"># ... or switching back to tnp layers!</span> <span class="n">TNPDense</span><span class="p">(</span><span class="mi">1</span><span class="p">),</span> <span class="p">]</span> <span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">create_mixed_model</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="s2">"adam"</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"mean_squared_error"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">MeanAbsolutePercentageError</span><span class="p">()],</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">build</span><span class="p">((</span><span class="kc">None</span><span class="p">,</span> <span class="n">input_dim</span><span class="p">))</span> <span class="n">model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> <span 
class="n">evaluate_model</span><span class="p">(</span><span class="n">model</span><span class="p">)</span> </code></pre></div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "sequential_1"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ tnp_dense_3 (<span style="color: #0087ff; text-decoration-color: #0087ff">TNPDense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">27</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ dense (<span style="color: #0087ff; text-decoration-color: #0087ff">Dense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">12</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ tnp_dense_4 (<span style="color: #0087ff; text-decoration-color: #0087ff">TNPDense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">4</span> │ └─────────────────────────────────┴───────────────────────────┴────────────┘ </pre> <pre 
style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">43</span> (172.00 B) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">43</span> (172.00 B) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Mean absolute percent error before training: 100.0 Mean absolute percent error after training: 55.646610260009766 </code></pre></div> </div> <p>The Keras API offers a wide variety of layers. The ability to use them alongside NumPy code can be a huge time saver in projects.</p> <hr /> <h2 id="distribution-strategy">Distribution Strategy</h2> <p>TensorFlow NumPy and Keras integrate with <a href="https://www.tensorflow.org/guide/distributed_training">TensorFlow Distribution Strategies</a>. 
This makes it simple to perform distributed training across multiple GPUs, or even an entire TPU Pod.</p> <div class="codehilite"><pre><span></span><code><span class="n">gpus</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">config</span><span class="o">.</span><span class="n">list_logical_devices</span><span class="p">(</span><span class="s2">"GPU"</span><span class="p">)</span> <span class="k">if</span> <span class="n">gpus</span><span class="p">:</span> <span class="n">strategy</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">distribute</span><span class="o">.</span><span class="n">MirroredStrategy</span><span class="p">(</span><span class="n">gpus</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="c1"># No GPUs found: fall back to the default (no-op) strategy on CPU.</span> <span class="n">strategy</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">distribute</span><span class="o">.</span><span class="n">get_strategy</span><span class="p">()</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Running with strategy:"</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="n">strategy</span><span class="o">.</span><span class="vm">__class__</span><span class="o">.</span><span class="vm">__name__</span><span class="p">))</span> <span class="k">with</span> <span class="n">strategy</span><span class="o">.</span><span class="n">scope</span><span class="p">():</span> <span class="n">model</span> <span class="o">=</span> <span class="n">create_layered_tnp_model</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="s2">"adam"</span><span class="p">,</span> <span class="n">loss</span><span
class="o">=</span><span class="s2">"mean_squared_error"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">MeanAbsolutePercentageError</span><span class="p">()],</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">build</span><span class="p">((</span><span class="kc">None</span><span class="p">,</span> <span class="n">input_dim</span><span class="p">))</span> <span class="n">model</span><span class="o">.</span><span class="n">summary</span><span class="p">()</span> <span class="n">evaluate_model</span><span class="p">(</span><span class="n">model</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',) Running with strategy: MirroredStrategy </code></pre></div> </div> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold">Model: "sequential_2"</span> </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓ ┃<span style="font-weight: bold"> Layer (type) </span>┃<span style="font-weight: bold"> Output Shape </span>┃<span style="font-weight: bold"> Param # </span>┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩ │ tnp_dense_5 (<span style="color: #0087ff; text-decoration-color: #0087ff">TNPDense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: 
#00af00; text-decoration-color: #00af00">27</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ tnp_dense_6 (<span style="color: #0087ff; text-decoration-color: #0087ff">TNPDense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">3</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">12</span> │ ├─────────────────────────────────┼───────────────────────────┼────────────┤ │ tnp_dense_7 (<span style="color: #0087ff; text-decoration-color: #0087ff">TNPDense</span>) │ (<span style="color: #00d7ff; text-decoration-color: #00d7ff">None</span>, <span style="color: #00af00; text-decoration-color: #00af00">1</span>) │ <span style="color: #00af00; text-decoration-color: #00af00">4</span> │ └─────────────────────────────────┴───────────────────────────┴────────────┘ </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Total params: </span><span style="color: #00af00; text-decoration-color: #00af00">43</span> (172.00 B) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">43</span> (172.00 B) </pre> <pre style="white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace"><span style="font-weight: bold"> Non-trainable params: </span><span style="color: #00af00; text-decoration-color: #00af00">0</span> (0.00 B) </pre> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Mean absolute percent error before training: 100.0 Mean absolute percent error after training: 44.573463439941406 </code></pre></div> </div> <hr /> <h2 
id="tensorboard-integration">TensorBoard Integration</h2> <p>One of the many benefits of using the Keras API is the ability to monitor training through TensorBoard. Models written with the TensorFlow NumPy API can use TensorBoard in the same way.</p> <div class="codehilite"><pre><span></span><code><span class="n">keras</span><span class="o">.</span><span class="n">backend</span><span class="o">.</span><span class="n">clear_session</span><span class="p">()</span> </code></pre></div> <p>To load the TensorBoard extension in a Jupyter notebook, run the following magic:</p> <div class="codehilite"><pre><span></span><code>%load_ext tensorboard </code></pre></div> <div class="codehilite"><pre><span></span><code><span class="n">models</span> <span class="o">=</span> <span class="p">[</span> <span class="p">(</span> <span class="n">TNPForwardFeedRegressionNetwork</span><span class="p">(</span><span class="n">blocks</span><span class="o">=</span><span class="p">[</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">]),</span> <span class="s2">"TNPForwardFeedRegressionNetwork"</span><span class="p">,</span> <span class="p">),</span> <span class="p">(</span><span class="n">create_layered_tnp_model</span><span class="p">(),</span> <span class="s2">"layered_tnp_model"</span><span class="p">),</span> <span class="p">(</span><span class="n">create_mixed_model</span><span class="p">(),</span> <span class="s2">"mixed_model"</span><span class="p">),</span> <span class="p">]</span> <span class="k">for</span> <span class="n">model</span><span class="p">,</span> <span class="n">model_name</span> <span class="ow">in</span> <span class="n">models</span><span class="p">:</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="s2">"adam"</span><span class="p">,</span> <span class="n">loss</span><span
class="o">=</span><span class="s2">"mean_squared_error"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">MeanAbsolutePercentageError</span><span class="p">()],</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">x_train</span><span class="p">,</span> <span class="n">y_train</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">200</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">TensorBoard</span><span class="p">(</span><span class="n">log_dir</span><span class="o">=</span><span class="sa">f</span><span class="s2">"logs/</span><span class="si">{</span><span class="n">model_name</span><span class="si">}</span><span class="s2">"</span><span class="p">)],</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>/opt/conda/envs/keras-tensorflow/lib/python3.10/site-packages/keras/src/callbacks/tensorboard.py:676: UserWarning: Model failed to serialize as JSON. Ignoring... Invalid format specifier warnings.warn(f"Model failed to serialize as JSON. Ignoring... 
{exc}") </code></pre></div> </div> <p>To launch TensorBoard from a Jupyter notebook, you can use the <code>%tensorboard</code> magic:</p> <div class="codehilite"><pre><span></span><code>%tensorboard --logdir logs </code></pre></div> <p>TensorBoard lets you monitor metrics and examine the training curve.</p> <p><img alt="TensorBoard training graph" src="https://i.imgur.com/wsOuFnz.png" /></p> <p>TensorBoard also allows you to explore the computation graph used in your models.</p> <p><img alt="TensorBoard graph exploration" src="https://i.imgur.com/tOrezDL.png" /></p> <p>The ability to introspect your models can be valuable during debugging.</p> <hr /> <h2 id="conclusion">Conclusion</h2> <p>Porting existing NumPy code to Keras models using the <code>tensorflow_numpy</code> API is easy! By integrating with Keras, you gain the ability to use existing Keras callbacks, metrics, and optimizers, easily distribute your training, and use TensorBoard.</p> <p>Migrating a more complex model, such as a ResNet, to the TensorFlow NumPy API would be a great follow-up learning exercise.</p> <p>Several open-source NumPy ResNet implementations are available online.</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#writing-keras-models-with-tensorflow-numpy'>Writing Keras Models With TensorFlow NumPy</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#subclassing-kerasmodel-with-tnp'>Subclassing keras.Model with TNP</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#implementing-a-keras-layer-based-model-with-tnp'>Implementing a Keras Layer Based Model with TNP</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#distribution-strategy'>Distribution Strategy</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#tensorboard-integration'>TensorBoard Integration</a> </div> <div class='k-outline-depth-2'>
◆ <a href='#conclusion'>Conclusion</a> </div> </div> </div> </div> </div> </body> </html>