CINXE.COM
Classification with Neural Decision Forests
<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/structured_data/deep_neural_decision_forests/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Classification with Neural Decision Forests"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Classification with Neural Decision Forests"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Classification with Neural Decision Forests</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink active" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink2" href="/examples/structured_data/structured_data_classification_with_feature_space/">Structured data classification with FeatureSpace</a> <a class="nav-sublink2" href="/examples/structured_data/feature_space_advanced/">FeatureSpace advanced use cases</a> <a class="nav-sublink2" href="/examples/structured_data/imbalanced_classification/">Imbalanced classification: credit card fraud detection</a> <a class="nav-sublink2" href="/examples/structured_data/structured_data_classification_from_scratch/">Structured data classification from scratch</a> <a class="nav-sublink2" href="/examples/structured_data/wide_deep_cross_networks/">Structured data learning with Wide, Deep, and Cross networks</a> <a class="nav-sublink2" href="/examples/structured_data/classification_with_grn_and_vsn/">Classification with Gated Residual and Variable Selection Networks</a> <a class="nav-sublink2" href="/examples/structured_data/classification_with_tfdf/">Classification with TensorFlow Decision Forests</a> <a class="nav-sublink2 active" href="/examples/structured_data/deep_neural_decision_forests/">Classification with Neural Decision Forests</a> <a class="nav-sublink2" href="/examples/structured_data/tabtransformer/">Structured data learning with TabTransformer</a> <a class="nav-sublink2" href="/examples/structured_data/collaborative_filtering_movielens/">Collaborative Filtering for Movie Recommendations</a> <a class="nav-sublink2" href="/examples/structured_data/movielens_recommendations_transformers/">A Transformer-based recommendation system</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/structured_data/'>Structured Data</a> / Classification with Neural Decision Forests </div> <div class='k-content'> <h1 id="classification-with-neural-decision-forests">Classification with Neural Decision Forests</h1> <p><strong>Author:</strong> <a href="https://www.linkedin.com/in/khalid-salama-24403144/">Khalid Salama</a><br> <strong>Date created:</strong> 2021/01/15<br> <strong>Last modified:</strong> 2021/01/15<br> <strong>Description:</strong> How to train differentiable decision trees for end-to-end learning in deep neural networks.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/structured_data/ipynb/deep_neural_decision_forests.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/structured_data/deep_neural_decision_forests.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>This example provides an implementation of the <a href="https://ieeexplore.ieee.org/document/7410529">Deep Neural Decision Forest</a> model introduced by P. Kontschieder et al. for structured data classification. It demonstrates how to build a stochastic and differentiable decision tree model, train it end-to-end, and unify decision trees with deep representation learning.</p> <hr /> <h2 id="the-dataset">The dataset</h2> <p>This example uses the <a href="https://archive.ics.uci.edu/ml/datasets/census+income">United States Census Income Dataset</a> provided by the <a href="https://archive.ics.uci.edu/ml/index.php">UC Irvine Machine Learning Repository</a>. The task is binary classification to predict whether a person is likely to be making over USD 50,000 a year.</p> <p>The dataset includes 48,842 instances with 14 input features (such as age, work class, education, occupation, and so on): 5 numerical features and 9 categorical features.</p> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">keras</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> <span class="kn">from</span> <span class="nn">keras.layers</span> <span class="kn">import</span> <span class="n">StringLookup</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">ops</span> <span class="kn">from</span> <span class="nn">tensorflow</span> <span class="kn">import</span> <span class="n">data</span> <span class="k">as</span> <span class="n">tf_data</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span> <span class="kn">import</span> <span class="nn">math</span> </code></pre></div> <hr /> <h2 id="prepare-the-data">Prepare the data</h2> <div class="codehilite"><pre><span></span><code><span class="n">CSV_HEADER</span> <span class="o">=</span> <span class="p">[</span> <span class="s2">"age"</span><span class="p">,</span> <span class="s2">"workclass"</span><span class="p">,</span> <span class="s2">"fnlwgt"</span><span class="p">,</span> <span class="s2">"education"</span><span class="p">,</span> <span class="s2">"education_num"</span><span class="p">,</span> <span class="s2">"marital_status"</span><span class="p">,</span> <span class="s2">"occupation"</span><span class="p">,</span> <span class="s2">"relationship"</span><span class="p">,</span> <span class="s2">"race"</span><span class="p">,</span> <span class="s2">"gender"</span><span class="p">,</span> <span class="s2">"capital_gain"</span><span class="p">,</span> <span class="s2">"capital_loss"</span><span class="p">,</span> <span class="s2">"hours_per_week"</span><span class="p">,</span> <span class="s2">"native_country"</span><span class="p">,</span> <span class="s2">"income_bracket"</span><span class="p">,</span> <span class="p">]</span> <span class="n">train_data_url</span> <span class="o">=</span> <span class="p">(</span> <span class="s2">"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"</span> <span class="p">)</span> <span class="n">train_data</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_csv</span><span class="p">(</span><span class="n">train_data_url</span><span class="p">,</span> <span class="n">header</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">names</span><span class="o">=</span><span class="n">CSV_HEADER</span><span class="p">)</span> <span class="n">test_data_url</span> <span class="o">=</span> <span class="p">(</span> <span class="s2">"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test"</span> <span class="p">)</span> <span class="n">test_data</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_csv</span><span class="p">(</span><span class="n">test_data_url</span><span class="p">,</span> <span class="n">header</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">names</span><span class="o">=</span><span class="n">CSV_HEADER</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Train dataset shape: </span><span class="si">{</span><span class="n">train_data</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Test dataset shape: </span><span class="si">{</span><span class="n">test_data</span><span class="o">.</span><span class="n">shape</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Train dataset shape: (32561, 15) Test dataset shape: (16282, 15) </code></pre></div> </div> <p>Remove the first record (because it is not a valid data example) and a trailing 'dot' in the class labels.</p> <div class="codehilite"><pre><span></span><code><span class="n">test_data</span> <span class="o">=</span> <span class="n">test_data</span><span class="p">[</span><span class="mi">1</span><span class="p">:]</span> <span class="n">test_data</span><span class="o">.</span><span class="n">income_bracket</span> <span class="o">=</span> <span class="n">test_data</span><span class="o">.</span><span class="n">income_bracket</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span> <span class="k">lambda</span> <span class="n">value</span><span class="p">:</span> <span class="n">value</span><span class="o">.</span><span class="n">replace</span><span class="p">(</span><span class="s2">"."</span><span class="p">,</span> <span class="s2">""</span><span class="p">)</span> <span class="p">)</span> </code></pre></div> <p>We store the training and test data splits locally as CSV files.</p> <div class="codehilite"><pre><span></span><code><span class="n">train_data_file</span> <span class="o">=</span> <span class="s2">"train_data.csv"</span> <span class="n">test_data_file</span> <span class="o">=</span> <span class="s2">"test_data.csv"</span> <span class="n">train_data</span><span class="o">.</span><span class="n">to_csv</span><span class="p">(</span><span class="n">train_data_file</span><span class="p">,</span> <span class="n">index</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">header</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="n">test_data</span><span class="o">.</span><span class="n">to_csv</span><span class="p">(</span><span class="n">test_data_file</span><span class="p">,</span> <span class="n">index</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">header</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="define-dataset-metadata">Define dataset metadata</h2> <p>Here, we define the metadata of the dataset that will be useful for reading and parsing and encoding input features.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># A list of the numerical feature names.</span> <span class="n">NUMERIC_FEATURE_NAMES</span> <span class="o">=</span> <span class="p">[</span> <span class="s2">"age"</span><span class="p">,</span> <span class="s2">"education_num"</span><span class="p">,</span> <span class="s2">"capital_gain"</span><span class="p">,</span> <span class="s2">"capital_loss"</span><span class="p">,</span> <span class="s2">"hours_per_week"</span><span class="p">,</span> <span class="p">]</span> <span class="c1"># A dictionary of the categorical features and their vocabulary.</span> <span class="n">CATEGORICAL_FEATURES_WITH_VOCABULARY</span> <span class="o">=</span> <span class="p">{</span> <span class="s2">"workclass"</span><span class="p">:</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">train_data</span><span class="p">[</span><span class="s2">"workclass"</span><span class="p">]</span><span class="o">.</span><span class="n">unique</span><span class="p">())),</span> <span class="s2">"education"</span><span class="p">:</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">train_data</span><span class="p">[</span><span class="s2">"education"</span><span class="p">]</span><span class="o">.</span><span class="n">unique</span><span class="p">())),</span> <span class="s2">"marital_status"</span><span class="p">:</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">train_data</span><span class="p">[</span><span class="s2">"marital_status"</span><span class="p">]</span><span class="o">.</span><span class="n">unique</span><span class="p">())),</span> <span class="s2">"occupation"</span><span class="p">:</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">train_data</span><span class="p">[</span><span class="s2">"occupation"</span><span class="p">]</span><span class="o">.</span><span class="n">unique</span><span class="p">())),</span> <span class="s2">"relationship"</span><span class="p">:</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">train_data</span><span class="p">[</span><span class="s2">"relationship"</span><span class="p">]</span><span class="o">.</span><span class="n">unique</span><span class="p">())),</span> <span class="s2">"race"</span><span class="p">:</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">train_data</span><span class="p">[</span><span class="s2">"race"</span><span class="p">]</span><span class="o">.</span><span class="n">unique</span><span class="p">())),</span> <span class="s2">"gender"</span><span class="p">:</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">train_data</span><span class="p">[</span><span class="s2">"gender"</span><span class="p">]</span><span class="o">.</span><span class="n">unique</span><span class="p">())),</span> <span class="s2">"native_country"</span><span class="p">:</span> <span class="nb">sorted</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">train_data</span><span class="p">[</span><span class="s2">"native_country"</span><span class="p">]</span><span class="o">.</span><span class="n">unique</span><span class="p">())),</span> <span class="p">}</span> <span class="c1"># A list of the columns to ignore from the dataset.</span> <span class="n">IGNORE_COLUMN_NAMES</span> <span class="o">=</span> <span class="p">[</span><span class="s2">"fnlwgt"</span><span class="p">]</span> <span class="c1"># A list of the categorical feature names.</span> <span class="n">CATEGORICAL_FEATURE_NAMES</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">CATEGORICAL_FEATURES_WITH_VOCABULARY</span><span class="o">.</span><span class="n">keys</span><span class="p">())</span> <span class="c1"># A list of all the input features.</span> <span class="n">FEATURE_NAMES</span> <span class="o">=</span> <span class="n">NUMERIC_FEATURE_NAMES</span> <span class="o">+</span> <span class="n">CATEGORICAL_FEATURE_NAMES</span> <span class="c1"># A list of column default values for each feature.</span> <span class="n">COLUMN_DEFAULTS</span> <span class="o">=</span> <span class="p">[</span> <span class="p">[</span><span class="mf">0.0</span><span class="p">]</span> <span class="k">if</span> <span class="n">feature_name</span> <span class="ow">in</span> <span class="n">NUMERIC_FEATURE_NAMES</span> <span class="o">+</span> <span class="n">IGNORE_COLUMN_NAMES</span> <span class="k">else</span> <span class="p">[</span><span class="s2">"NA"</span><span class="p">]</span> <span class="k">for</span> <span class="n">feature_name</span> <span class="ow">in</span> <span class="n">CSV_HEADER</span> <span class="p">]</span> <span class="c1"># The name of the target feature.</span> <span class="n">TARGET_FEATURE_NAME</span> <span class="o">=</span> <span class="s2">"income_bracket"</span> <span class="c1"># A list of the labels of the target features.</span> <span class="n">TARGET_LABELS</span> <span class="o">=</span> <span class="p">[</span><span class="s2">" <=50K"</span><span class="p">,</span> <span class="s2">" >50K"</span><span class="p">]</span> </code></pre></div> <hr /> <h2 id="create-tfdatadataset-objects-for-training-and-validation">Create <code>tf_data.Dataset</code> objects for training and validation</h2> <p>We create an input function to read and parse the file, and convert features and labels into a <a href="https://www.tensorflow.org/guide/datasets"><code>tf_data.Dataset</code></a> for training and validation. We also preprocess the input by mapping the target label to an index.</p> <div class="codehilite"><pre><span></span><code><span class="n">target_label_lookup</span> <span class="o">=</span> <span class="n">StringLookup</span><span class="p">(</span> <span class="n">vocabulary</span><span class="o">=</span><span class="n">TARGET_LABELS</span><span class="p">,</span> <span class="n">mask_token</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">num_oov_indices</span><span class="o">=</span><span class="mi">0</span> <span class="p">)</span> <span class="n">lookup_dict</span> <span class="o">=</span> <span class="p">{}</span> <span class="k">for</span> <span class="n">feature_name</span> <span class="ow">in</span> <span class="n">CATEGORICAL_FEATURE_NAMES</span><span class="p">:</span> <span class="n">vocabulary</span> <span class="o">=</span> <span class="n">CATEGORICAL_FEATURES_WITH_VOCABULARY</span><span class="p">[</span><span class="n">feature_name</span><span class="p">]</span> <span class="c1"># Create a lookup to convert a string values to an integer indices.</span> <span class="c1"># Since we are not using a mask token, nor expecting any out of vocabulary</span> <span class="c1"># (oov) token, we set mask_token to None and num_oov_indices to 0.</span> <span class="n">lookup</span> <span class="o">=</span> <span class="n">StringLookup</span><span class="p">(</span><span class="n">vocabulary</span><span class="o">=</span><span class="n">vocabulary</span><span class="p">,</span> <span class="n">mask_token</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">num_oov_indices</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span> <span class="n">lookup_dict</span><span class="p">[</span><span class="n">feature_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">lookup</span> <span class="k">def</span> <span class="nf">encode_categorical</span><span class="p">(</span><span class="n">batch_x</span><span class="p">,</span> <span class="n">batch_y</span><span class="p">):</span> <span class="k">for</span> <span class="n">feature_name</span> <span class="ow">in</span> <span class="n">CATEGORICAL_FEATURE_NAMES</span><span class="p">:</span> <span class="n">batch_x</span><span class="p">[</span><span class="n">feature_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">lookup_dict</span><span class="p">[</span><span class="n">feature_name</span><span class="p">](</span><span class="n">batch_x</span><span class="p">[</span><span class="n">feature_name</span><span class="p">])</span> <span class="k">return</span> <span class="n">batch_x</span><span class="p">,</span> <span class="n">batch_y</span> <span class="k">def</span> <span class="nf">get_dataset_from_csv</span><span class="p">(</span><span class="n">csv_file_path</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">128</span><span class="p">):</span> <span class="n">dataset</span> <span class="o">=</span> <span class="p">(</span> <span class="n">tf_data</span><span class="o">.</span><span class="n">experimental</span><span class="o">.</span><span class="n">make_csv_dataset</span><span class="p">(</span> <span class="n">csv_file_path</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">column_names</span><span class="o">=</span><span class="n">CSV_HEADER</span><span class="p">,</span> <span class="n">column_defaults</span><span class="o">=</span><span class="n">COLUMN_DEFAULTS</span><span class="p">,</span> <span class="n">label_name</span><span class="o">=</span><span class="n">TARGET_FEATURE_NAME</span><span class="p">,</span> <span class="n">num_epochs</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">header</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">na_value</span><span class="o">=</span><span class="s2">"?"</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="n">shuffle</span><span class="p">,</span> <span class="p">)</span> <span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">features</span><span class="p">,</span> <span class="n">target</span><span class="p">:</span> <span class="p">(</span><span class="n">features</span><span class="p">,</span> <span class="n">target_label_lookup</span><span class="p">(</span><span class="n">target</span><span class="p">)))</span> <span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="n">encode_categorical</span><span class="p">)</span> <span class="p">)</span> <span class="k">return</span> <span class="n">dataset</span><span class="o">.</span><span class="n">cache</span><span class="p">()</span> </code></pre></div> <hr /> <h2 id="create-model-inputs">Create model inputs</h2> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">create_model_inputs</span><span class="p">():</span> <span class="n">inputs</span> <span class="o">=</span> <span class="p">{}</span> <span class="k">for</span> <span class="n">feature_name</span> <span class="ow">in</span> <span class="n">FEATURE_NAMES</span><span class="p">:</span> <span class="k">if</span> <span class="n">feature_name</span> <span class="ow">in</span> <span class="n">NUMERIC_FEATURE_NAMES</span><span class="p">:</span> <span class="n">inputs</span><span class="p">[</span><span class="n">feature_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span> <span class="n">name</span><span class="o">=</span><span class="n">feature_name</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span> <span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">inputs</span><span class="p">[</span><span class="n">feature_name</span><span class="p">]</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span> <span class="n">name</span><span class="o">=</span><span class="n">feature_name</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">(),</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int32"</span> <span class="p">)</span> <span class="k">return</span> <span class="n">inputs</span> </code></pre></div> <hr /> <h2 id="encode-input-features">Encode input features</h2> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">encode_inputs</span><span class="p">(</span><span class="n">inputs</span><span class="p">):</span> <span class="n">encoded_features</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">feature_name</span> <span class="ow">in</span> <span class="n">inputs</span><span class="p">:</span> <span class="k">if</span> <span class="n">feature_name</span> <span class="ow">in</span> <span class="n">CATEGORICAL_FEATURE_NAMES</span><span class="p">:</span> <span class="n">vocabulary</span> <span class="o">=</span> <span class="n">CATEGORICAL_FEATURES_WITH_VOCABULARY</span><span class="p">[</span><span class="n">feature_name</span><span class="p">]</span> <span class="c1"># Create a lookup to convert a string values to an integer indices.</span> <span class="c1"># Since we are not using a mask token, nor expecting any out of vocabulary</span> <span class="c1"># (oov) token, we set mask_token to None and num_oov_indices to 0.</span> <span class="n">value_index</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="n">feature_name</span><span class="p">]</span> <span class="n">embedding_dims</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">math</span><span class="o">.</span><span class="n">sqrt</span><span class="p">(</span><span class="n">lookup</span><span class="o">.</span><span class="n">vocabulary_size</span><span class="p">()))</span> <span class="c1"># Create an embedding layer with the specified dimensions.</span> <span class="n">embedding</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span> <span class="n">input_dim</span><span class="o">=</span><span class="n">lookup</span><span class="o">.</span><span class="n">vocabulary_size</span><span class="p">(),</span> <span class="n">output_dim</span><span class="o">=</span><span class="n">embedding_dims</span> <span class="p">)</span> <span class="c1"># Convert the index values to embedding representations.</span> <span class="n">encoded_feature</span> <span class="o">=</span> <span class="n">embedding</span><span class="p">(</span><span class="n">value_index</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="c1"># Use the numerical features as-is.</span> <span class="n">encoded_feature</span> <span class="o">=</span> <span class="n">inputs</span><span class="p">[</span><span class="n">feature_name</span><span class="p">]</span> <span class="k">if</span> <span class="n">inputs</span><span class="p">[</span><span class="n">feature_name</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span> <span class="n">encoded_feature</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">ops</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">encoded_feature</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">)</span> <span class="n">encoded_features</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">encoded_feature</span><span class="p">)</span> <span class="n">encoded_features</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span><span class="n">encoded_features</span><span class="p">)</span> <span class="k">return</span> <span class="n">encoded_features</span> </code></pre></div> <hr /> <h2 id="deep-neural-decision-tree">Deep Neural Decision Tree</h2> <p>A neural decision tree model has two sets of weights to learn. The first set is <code>pi</code>, which represents the probability distribution of the classes in the tree leaves. The second set is the weights of the routing layer <code>decision_fn</code>, which represents the probability of going to each leave. The forward pass of the model works as follows:</p> <ol> <li>The model expects input <code>features</code> as a single vector encoding all the features of an instance in the batch. This vector can be generated from a Convolution Neural Network (CNN) applied to images or dense transformations applied to structured data features.</li> <li>The model first applies a <code>used_features_mask</code> to randomly select a subset of input features to use.</li> <li>Then, the model computes the probabilities (<code>mu</code>) for the input instances to reach the tree leaves by iteratively performing a <em>stochastic</em> routing throughout the tree levels.</li> <li>Finally, the probabilities of reaching the leaves are combined by the class probabilities at the leaves to produce the final <code>outputs</code>.</li> </ol> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">NeuralDecisionTree</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">depth</span><span class="p">,</span> <span class="n">num_features</span><span class="p">,</span> <span class="n">used_features_rate</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">depth</span> <span class="o">=</span> <span class="n">depth</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_leaves</span> <span class="o">=</span> <span class="mi">2</span><span class="o">**</span><span class="n">depth</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_classes</span> <span class="o">=</span> <span class="n">num_classes</span> <span class="c1"># Create a mask for the randomly selected features.</span> <span class="n">num_used_features</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">num_features</span> <span class="o">*</span> <span class="n">used_features_rate</span><span class="p">)</span> <span class="n">one_hot</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">num_features</span><span class="p">)</span> <span class="n">sampled_feature_indices</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">choice</span><span class="p">(</span> <span class="n">np</span><span class="o">.</span><span class="n">arange</span><span class="p">(</span><span class="n">num_features</span><span class="p">),</span> <span class="n">num_used_features</span><span class="p">,</span> <span class="n">replace</span><span class="o">=</span><span class="kc">False</span> <span class="p">)</span> <span class="bp">self</span><span class="o">.</span><span class="n">used_features_mask</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">convert_to_tensor</span><span class="p">(</span> <span class="n">one_hot</span><span class="p">[</span><span class="n">sampled_feature_indices</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span> <span class="p">)</span> <span class="c1"># Initialize the weights of the classes in leaves.</span> <span class="bp">self</span><span class="o">.</span><span class="n">pi</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span> <span class="n">initializer</span><span class="o">=</span><span class="s2">"random_normal"</span><span class="p">,</span> <span class="n">shape</span><span class="o">=</span><span class="p">[</span><span class="bp">self</span><span class="o">.</span><span class="n">num_leaves</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_classes</span><span class="p">],</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"float32"</span><span class="p">,</span> <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Initialize the stochastic routing layer.</span> <span class="bp">self</span><span class="o">.</span><span class="n">decision_fn</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span> <span class="n">units</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">num_leaves</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"sigmoid"</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">"decision"</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">features</span><span class="p">):</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">features</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="c1"># Apply the feature mask to the input features.</span> <span class="n">features</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span> <span class="n">features</span><span class="p">,</span> <span class="n">ops</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">used_features_mask</span><span class="p">)</span> <span class="p">)</span> <span class="c1"># [batch_size, num_used_features]</span> <span class="c1"># Compute the routing probabilities.</span> <span class="n">decisions</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span> <span class="bp">self</span><span class="o">.</span><span class="n">decision_fn</span><span class="p">(</span><span class="n">features</span><span class="p">),</span> <span class="n">axis</span><span class="o">=</span><span class="mi">2</span> <span class="p">)</span> <span class="c1"># [batch_size, num_leaves, 1]</span> <span class="c1"># Concatenate the routing probabilities with their complements.</span> <span class="n">decisions</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span> <span class="p">[</span><span class="n">decisions</span><span class="p">,</span> <span class="mi">1</span> <span class="o">-</span> <span class="n">decisions</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">2</span> <span class="p">)</span> <span class="c1"># [batch_size, num_leaves, 2]</span> <span class="n">mu</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">ones</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span> <span class="n">begin_idx</span> <span class="o">=</span> <span class="mi">1</span> <span class="n">end_idx</span> <span class="o">=</span> <span class="mi">2</span> <span class="c1"># Traverse the tree in breadth-first order.</span> <span class="k">for</span> <span class="n">level</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">depth</span><span class="p">):</span> <span class="n">mu</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">mu</span><span class="p">,</span> <span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span> <span class="c1"># [batch_size, 2 ** level, 1]</span> <span class="n">mu</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">tile</span><span class="p">(</span><span class="n">mu</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">))</span> <span class="c1"># [batch_size, 2 ** level, 2]</span> <span class="n">level_decisions</span> <span class="o">=</span> <span class="n">decisions</span><span class="p">[</span> <span class="p">:,</span> <span class="n">begin_idx</span><span class="p">:</span><span class="n">end_idx</span><span class="p">,</span> <span class="p">:</span> <span class="p">]</span> <span class="c1"># [batch_size, 2 ** level, 2]</span> <span class="n">mu</span> <span class="o">=</span> <span class="n">mu</span> <span class="o">*</span> <span class="n">level_decisions</span> <span class="c1"># [batch_size, 2**level, 2]</span> <span class="n">begin_idx</span> <span class="o">=</span> <span class="n">end_idx</span> <span class="n">end_idx</span> <span class="o">=</span> <span class="n">begin_idx</span> <span class="o">+</span> <span class="mi">2</span> <span class="o">**</span> <span class="p">(</span><span class="n">level</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="n">mu</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">mu</span><span class="p">,</span> <span class="p">[</span><span class="n">batch_size</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_leaves</span><span class="p">])</span> <span class="c1"># [batch_size, num_leaves]</span> <span class="n">probabilities</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">activations</span><span class="o">.</span><span class="n">softmax</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">pi</span><span class="p">)</span> <span class="c1"># [num_leaves, num_classes]</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">mu</span><span class="p">,</span> <span class="n">probabilities</span><span class="p">)</span> <span class="c1"># [batch_size, num_classes]</span> <span class="k">return</span> <span class="n">outputs</span> </code></pre></div> <hr /> <h2 id="deep-neural-decision-forest">Deep Neural Decision Forest</h2> <p>The neural decision forest model consists of a set of neural decision trees that are trained simultaneously. The output of the forest model is the average outputs of its trees.</p> <div class="codehilite"><pre><span></span><code><span class="k">class</span> <span class="nc">NeuralDecisionForest</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">):</span> <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_trees</span><span class="p">,</span> <span class="n">depth</span><span class="p">,</span> <span class="n">num_features</span><span class="p">,</span> <span class="n">used_features_rate</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">):</span> <span class="nb">super</span><span class="p">()</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span> <span class="bp">self</span><span class="o">.</span><span class="n">ensemble</span> <span class="o">=</span> <span class="p">[]</span> <span class="c1"># Initialize the ensemble by adding NeuralDecisionTree instances.</span> <span class="c1"># Each tree will have its own randomly selected input features to use.</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_trees</span><span class="p">):</span> <span class="bp">self</span><span class="o">.</span><span class="n">ensemble</span><span class="o">.</span><span class="n">append</span><span class="p">(</span> <span class="n">NeuralDecisionTree</span><span class="p">(</span><span class="n">depth</span><span class="p">,</span> <span class="n">num_features</span><span class="p">,</span> <span class="n">used_features_rate</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span> <span class="p">)</span> <span class="k">def</span> <span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">):</span> <span class="c1"># Initialize the outputs: a [batch_size, num_classes] matrix of zeros.</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">inputs</span><span class="p">)[</span><span class="mi">0</span><span class="p">]</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">ops</span><span class="o">.</span><span class="n">zeros</span><span class="p">([</span><span class="n">batch_size</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">])</span> <span class="c1"># Aggregate the outputs of trees in the ensemble.</span> <span class="k">for</span> <span class="n">tree</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">ensemble</span><span class="p">:</span> <span class="n">outputs</span> <span class="o">+=</span> <span class="n">tree</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="c1"># Divide the outputs by the ensemble size to get the average.</span> <span class="n">outputs</span> <span class="o">/=</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">ensemble</span><span class="p">)</span> <span class="k">return</span> <span class="n">outputs</span> </code></pre></div> <p>Finally, let's set up the code that will train and evaluate the model.</p> <div class="codehilite"><pre><span></span><code><span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">0.01</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">265</span> <span class="n">num_epochs</span> <span class="o">=</span> <span class="mi">10</span> <span class="k">def</span> <span class="nf">run_experiment</span><span class="p">(</span><span class="n">model</span><span class="p">):</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="n">learning_rate</span><span class="p">),</span> <span class="n">loss</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">SparseCategoricalCrossentropy</span><span class="p">(),</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">SparseCategoricalAccuracy</span><span class="p">()],</span> <span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Start training the model..."</span><span class="p">)</span> <span class="n">train_dataset</span> <span class="o">=</span> <span class="n">get_dataset_from_csv</span><span class="p">(</span> <span class="n">train_data_file</span><span class="p">,</span> <span class="n">shuffle</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span> <span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">num_epochs</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Model training finished"</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Evaluating the model on the test data..."</span><span class="p">)</span> <span class="n">test_dataset</span> <span class="o">=</span> <span class="n">get_dataset_from_csv</span><span class="p">(</span><span class="n">test_data_file</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">)</span> <span class="n">_</span><span class="p">,</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">test_dataset</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Test accuracy: </span><span class="si">{</span><span class="nb">round</span><span class="p">(</span><span class="n">accuracy</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="mi">100</span><span class="p">,</span><span class="w"> </span><span class="mi">2</span><span class="p">)</span><span class="si">}</span><span class="s2">%"</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="experiment-1-train-a-decision-tree-model">Experiment 1: train a decision tree model</h2> <p>In this experiment, we train a single neural decision tree model where we use all input features.</p> <div class="codehilite"><pre><span></span><code><span class="n">num_trees</span> <span class="o">=</span> <span class="mi">10</span> <span class="n">depth</span> <span class="o">=</span> <span class="mi">10</span> <span class="n">used_features_rate</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="n">num_classes</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">TARGET_LABELS</span><span class="p">)</span> <span class="k">def</span> <span class="nf">create_tree_model</span><span class="p">():</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">create_model_inputs</span><span class="p">()</span> <span class="n">features</span> <span class="o">=</span> <span class="n">encode_inputs</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">features</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">()(</span><span class="n">features</span><span class="p">)</span> <span class="n">num_features</span> <span class="o">=</span> <span class="n">features</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="n">tree</span> <span class="o">=</span> <span class="n">NeuralDecisionTree</span><span class="p">(</span><span class="n">depth</span><span class="p">,</span> <span class="n">num_features</span><span class="p">,</span> <span class="n">used_features_rate</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">tree</span><span class="p">(</span><span class="n">features</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">outputs</span><span class="p">)</span> <span class="k">return</span> <span class="n">model</span> <span class="n">tree_model</span> <span class="o">=</span> <span class="n">create_tree_model</span><span class="p">()</span> <span class="n">run_experiment</span><span class="p">(</span><span class="n">tree_model</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Start training the model... Epoch 1/10 123/123 ━━━━━━━━━━━━━━━━━━━━ 5s 26ms/step - loss: 0.5308 - sparse_categorical_accuracy: 0.8150 Epoch 2/10 123/123 ━━━━━━━━━━━━━━━━━━━━ 1s 11ms/step - loss: 0.3476 - sparse_categorical_accuracy: 0.8429 Epoch 3/10 123/123 ━━━━━━━━━━━━━━━━━━━━ 1s 11ms/step - loss: 0.3312 - sparse_categorical_accuracy: 0.8478 Epoch 4/10 123/123 ━━━━━━━━━━━━━━━━━━━━ 1s 11ms/step - loss: 0.3247 - sparse_categorical_accuracy: 0.8495 Epoch 5/10 123/123 ━━━━━━━━━━━━━━━━━━━━ 1s 10ms/step - loss: 0.3202 - sparse_categorical_accuracy: 0.8512 Epoch 6/10 123/123 ━━━━━━━━━━━━━━━━━━━━ 1s 11ms/step - loss: 0.3158 - sparse_categorical_accuracy: 0.8536 Epoch 7/10 123/123 ━━━━━━━━━━━━━━━━━━━━ 1s 11ms/step - loss: 0.3116 - sparse_categorical_accuracy: 0.8572 Epoch 8/10 123/123 ━━━━━━━━━━━━━━━━━━━━ 1s 11ms/step - loss: 0.3071 - sparse_categorical_accuracy: 0.8608 Epoch 9/10 123/123 ━━━━━━━━━━━━━━━━━━━━ 1s 11ms/step - loss: 0.3026 - sparse_categorical_accuracy: 0.8630 Epoch 10/10 123/123 ━━━━━━━━━━━━━━━━━━━━ 1s 10ms/step - loss: 0.2975 - sparse_categorical_accuracy: 0.8653 Model training finished Evaluating the model on the test data... 62/62 ━━━━━━━━━━━━━━━━━━━━ 1s 13ms/step - loss: 0.3279 - sparse_categorical_accuracy: 0.8463 Test accuracy: 85.08% </code></pre></div> </div> <hr /> <h2 id="experiment-2-train-a-forest-model">Experiment 2: train a forest model</h2> <p>In this experiment, we train a neural decision forest with <code>num_trees</code> trees where each tree uses randomly selected 50% of the input features. You can control the number of features to be used in each tree by setting the <code>used_features_rate</code> variable. In addition, we set the depth to 5 instead of 10 compared to the previous experiment.</p> <div class="codehilite"><pre><span></span><code><span class="n">num_trees</span> <span class="o">=</span> <span class="mi">25</span> <span class="n">depth</span> <span class="o">=</span> <span class="mi">5</span> <span class="n">used_features_rate</span> <span class="o">=</span> <span class="mf">0.5</span> <span class="k">def</span> <span class="nf">create_forest_model</span><span class="p">():</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">create_model_inputs</span><span class="p">()</span> <span class="n">features</span> <span class="o">=</span> <span class="n">encode_inputs</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">features</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">()(</span><span class="n">features</span><span class="p">)</span> <span class="n">num_features</span> <span class="o">=</span> <span class="n">features</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="n">forest_model</span> <span class="o">=</span> <span class="n">NeuralDecisionForest</span><span class="p">(</span> <span class="n">num_trees</span><span class="p">,</span> <span class="n">depth</span><span class="p">,</span> <span class="n">num_features</span><span class="p">,</span> <span class="n">used_features_rate</span><span class="p">,</span> <span class="n">num_classes</span> <span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">forest_model</span><span class="p">(</span><span class="n">features</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">outputs</span><span class="p">)</span> <span class="k">return</span> <span class="n">model</span> <span class="n">forest_model</span> <span class="o">=</span> <span class="n">create_forest_model</span><span class="p">()</span> <span class="n">run_experiment</span><span class="p">(</span><span class="n">forest_model</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Start training the model... Epoch 1/10 123/123 ━━━━━━━━━━━━━━━━━━━━ 47s 202ms/step - loss: 0.5469 - sparse_categorical_accuracy: 0.7915 Epoch 2/10 123/123 ━━━━━━━━━━━━━━━━━━━━ 1s 10ms/step - loss: 0.3459 - sparse_categorical_accuracy: 0.8494 Epoch 3/10 123/123 ━━━━━━━━━━━━━━━━━━━━ 1s 10ms/step - loss: 0.3268 - sparse_categorical_accuracy: 0.8523 Epoch 4/10 123/123 ━━━━━━━━━━━━━━━━━━━━ 1s 10ms/step - loss: 0.3195 - sparse_categorical_accuracy: 0.8524 Epoch 5/10 123/123 ━━━━━━━━━━━━━━━━━━━━ 1s 10ms/step - loss: 0.3149 - sparse_categorical_accuracy: 0.8539 Epoch 6/10 123/123 ━━━━━━━━━━━━━━━━━━━━ 1s 10ms/step - loss: 0.3112 - sparse_categorical_accuracy: 0.8556 Epoch 7/10 123/123 ━━━━━━━━━━━━━━━━━━━━ 1s 10ms/step - loss: 0.3079 - sparse_categorical_accuracy: 0.8566 Epoch 8/10 123/123 ━━━━━━━━━━━━━━━━━━━━ 1s 9ms/step - loss: 0.3050 - sparse_categorical_accuracy: 0.8582 Epoch 9/10 123/123 ━━━━━━━━━━━━━━━━━━━━ 1s 9ms/step - loss: 0.3021 - sparse_categorical_accuracy: 0.8595 Epoch 10/10 123/123 ━━━━━━━━━━━━━━━━━━━━ 1s 9ms/step - loss: 0.2992 - sparse_categorical_accuracy: 0.8617 Model training finished Evaluating the model on the test data... 62/62 ━━━━━━━━━━━━━━━━━━━━ 5s 39ms/step - loss: 0.3145 - sparse_categorical_accuracy: 0.8503 Test accuracy: 85.55% </code></pre></div> </div> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#classification-with-neural-decision-forests'>Classification with Neural Decision Forests</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#the-dataset'>The dataset</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#prepare-the-data'>Prepare the data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#define-dataset-metadata'>Define dataset metadata</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#create-tfdatadataset-objects-for-training-and-validation'>Create <code>tf_data.Dataset</code> objects for training and validation</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#create-model-inputs'>Create model inputs</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#encode-input-features'>Encode input features</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#deep-neural-decision-tree'>Deep Neural Decision Tree</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#deep-neural-decision-forest'>Deep Neural Decision Forest</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#experiment-1-train-a-decision-tree-model'>Experiment 1: train a decision tree model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#experiment-2-train-a-forest-model'>Experiment 2: train a forest model</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>