<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/structured_data/structured_data_classification_from_scratch/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Structured data classification from scratch"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Structured data classification from scratch"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Structured data classification from scratch</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 
'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink active" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink2" href="/examples/structured_data/structured_data_classification_with_feature_space/">Structured data classification with FeatureSpace</a> <a class="nav-sublink2" href="/examples/structured_data/feature_space_advanced/">FeatureSpace advanced use cases</a> <a class="nav-sublink2" href="/examples/structured_data/imbalanced_classification/">Imbalanced classification: credit card fraud detection</a> <a class="nav-sublink2 active" href="/examples/structured_data/structured_data_classification_from_scratch/">Structured data classification from scratch</a> <a class="nav-sublink2" href="/examples/structured_data/wide_deep_cross_networks/">Structured data learning with Wide, Deep, and Cross networks</a> <a class="nav-sublink2" 
href="/examples/structured_data/classification_with_grn_and_vsn/">Classification with Gated Residual and Variable Selection Networks</a> <a class="nav-sublink2" href="/examples/structured_data/classification_with_tfdf/">Classification with TensorFlow Decision Forests</a> <a class="nav-sublink2" href="/examples/structured_data/deep_neural_decision_forests/">Classification with Neural Decision Forests</a> <a class="nav-sublink2" href="/examples/structured_data/tabtransformer/">Structured data learning with TabTransformer</a> <a class="nav-sublink2" href="/examples/structured_data/collaborative_filtering_movielens/">Collaborative Filtering for Movie Recommendations</a> <a class="nav-sublink2" href="/examples/structured_data/movielens_recommendations_transformers/">A Transformer-based recommendation system</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { 
document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." 
autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-content'> <h1 id="structured-data-classification-from-scratch">Structured data classification from scratch</h1> <p><strong>Author:</strong> <a href="https://twitter.com/fchollet">fchollet</a><br> <strong>Date created:</strong> 2020/06/09<br> <strong>Last modified:</strong> 2020/06/09<br> <strong>Description:</strong> Binary classification of structured data including numerical and categorical features.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img
class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/structured_data/ipynb/structured_data_classification_from_scratch.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/structured_data/structured_data_classification_from_scratch.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>This example demonstrates how to do structured data classification, starting from a raw CSV file. Our data includes both numerical and categorical features. We will use Keras preprocessing layers to normalize the numerical features and vectorize the categorical ones.</p> <p>Note that this example should be run with TensorFlow 2.5 or higher.</p> <h3 id="the-dataset">The dataset</h3> <p><a href="https://archive.ics.uci.edu/ml/datasets/heart+Disease">Our dataset</a> is provided by the Cleveland Clinic Foundation for Heart Disease. It's a CSV file with 303 rows. Each row contains information about a patient (a <strong>sample</strong>), and each column describes an attribute of the patient (a <strong>feature</strong>). 
We use the features to predict whether a patient has a heart disease (<strong>binary classification</strong>).</p> <p>Here's the description of each feature:</p> <table> <thead> <tr> <th>Column</th> <th>Description</th> <th>Feature Type</th> </tr> </thead> <tbody> <tr> <td>Age</td> <td>Age in years</td> <td>Numerical</td> </tr> <tr> <td>Sex</td> <td>(1 = male; 0 = female)</td> <td>Categorical</td> </tr> <tr> <td>CP</td> <td>Chest pain type (0, 1, 2, 3, 4)</td> <td>Categorical</td> </tr> <tr> <td>Trestbps</td> <td>Resting blood pressure (in mm Hg on admission)</td> <td>Numerical</td> </tr> <tr> <td>Chol</td> <td>Serum cholesterol in mg/dl</td> <td>Numerical</td> </tr> <tr> <td>FBS</td> <td>Fasting blood sugar &gt; 120 mg/dl (1 = true; 0 = false)</td> <td>Categorical</td> </tr> <tr> <td>RestECG</td> <td>Resting electrocardiogram results (0, 1, 2)</td> <td>Categorical</td> </tr> <tr> <td>Thalach</td> <td>Maximum heart rate achieved</td> <td>Numerical</td> </tr> <tr> <td>Exang</td> <td>Exercise-induced angina (1 = yes; 0 = no)</td> <td>Categorical</td> </tr> <tr> <td>Oldpeak</td> <td>ST depression induced by exercise relative to rest</td> <td>Numerical</td> </tr> <tr> <td>Slope</td> <td>Slope of the peak exercise ST segment</td> <td>Numerical</td> </tr> <tr> <td>CA</td> <td>Number of major vessels (0-3) colored by fluoroscopy</td> <td>Both numerical &amp; categorical</td> </tr> <tr> <td>Thal</td> <td>3 = normal; 6 = fixed defect; 7 = reversible defect (stored in this CSV as strings such as "normal", "fixed", and "reversible")</td> <td>Categorical</td> </tr> <tr> <td>Target</td> <td>Diagnosis of heart disease (1 = true; 0 = false)</td> <td>Target</td> </tr> </tbody> </table> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span> <span class="c1"># TensorFlow is the only backend that supports string inputs.</span> <span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span
class="p">]</span> <span class="o">=</span> <span class="s2">"tensorflow"</span> <span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span> <span class="kn">import</span> <span class="nn">pandas</span> <span class="k">as</span> <span class="nn">pd</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> </code></pre></div> <hr /> <h2 id="preparing-the-data">Preparing the data</h2> <p>Let's download the data and load it into a Pandas dataframe:</p> <div class="codehilite"><pre><span></span><code><span class="n">file_url</span> <span class="o">=</span> <span class="s2">"http://storage.googleapis.com/download.tensorflow.org/data/heart.csv"</span> <span class="n">dataframe</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_csv</span><span class="p">(</span><span class="n">file_url</span><span class="p">)</span> </code></pre></div> <p>The dataset includes 303 samples with 14 columns per sample (13 features, plus the target label):</p> <div class="codehilite"><pre><span></span><code><span class="n">dataframe</span><span class="o">.</span><span class="n">shape</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>(303, 14) </code></pre></div> </div> <p>Here's a preview of a few samples:</p> <div class="codehilite"><pre><span></span><code><span class="n">dataframe</span><span class="o">.</span><span class="n">head</span><span class="p">()</span> </code></pre></div> <div> <style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; } .dataframe tbody tr th { vertical-align: top; } .dataframe thead th { text-align: right; } </style> <table border="1"
class="dataframe"> <thead> <tr style="text-align: right;"> <th></th> <th>age</th> <th>sex</th> <th>cp</th> <th>trestbps</th> <th>chol</th> <th>fbs</th> <th>restecg</th> <th>thalach</th> <th>exang</th> <th>oldpeak</th> <th>slope</th> <th>ca</th> <th>thal</th> <th>target</th> </tr> </thead> <tbody> <tr> <th>0</th> <td>63</td> <td>1</td> <td>1</td> <td>145</td> <td>233</td> <td>1</td> <td>2</td> <td>150</td> <td>0</td> <td>2.3</td> <td>3</td> <td>0</td> <td>fixed</td> <td>0</td> </tr> <tr> <th>1</th> <td>67</td> <td>1</td> <td>4</td> <td>160</td> <td>286</td> <td>0</td> <td>2</td> <td>108</td> <td>1</td> <td>1.5</td> <td>2</td> <td>3</td> <td>normal</td> <td>1</td> </tr> <tr> <th>2</th> <td>67</td> <td>1</td> <td>4</td> <td>120</td> <td>229</td> <td>0</td> <td>2</td> <td>129</td> <td>1</td> <td>2.6</td> <td>2</td> <td>2</td> <td>reversible</td> <td>0</td> </tr> <tr> <th>3</th> <td>37</td> <td>1</td> <td>3</td> <td>130</td> <td>250</td> <td>0</td> <td>0</td> <td>187</td> <td>0</td> <td>3.5</td> <td>3</td> <td>0</td> <td>normal</td> <td>0</td> </tr> <tr> <th>4</th> <td>41</td> <td>0</td> <td>2</td> <td>130</td> <td>204</td> <td>0</td> <td>2</td> <td>172</td> <td>0</td> <td>1.4</td> <td>1</td> <td>0</td> <td>normal</td> <td>0</td> </tr> </tbody> </table> </div> <p>The last column, "target", indicates whether the patient has a heart disease (1) or not (0).</p> <p>Let's split the data into a training and validation set:</p> <div class="codehilite"><pre><span></span><code><span class="n">val_dataframe</span> <span class="o">=</span> <span class="n">dataframe</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">frac</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">random_state</span><span class="o">=</span><span class="mi">1337</span><span class="p">)</span> <span class="n">train_dataframe</span> <span class="o">=</span> <span class="n">dataframe</span><span class="o">.</span><span 
class="n">drop</span><span class="p">(</span><span class="n">val_dataframe</span><span class="o">.</span><span class="n">index</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span> <span class="sa">f</span><span class="s2">"Using </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">train_dataframe</span><span class="p">)</span><span class="si">}</span><span class="s2"> samples for training "</span> <span class="sa">f</span><span class="s2">"and </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">val_dataframe</span><span class="p">)</span><span class="si">}</span><span class="s2"> for validation"</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Using 242 samples for training and 61 for validation </code></pre></div> </div> <p>Let's generate <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><code>tf.data.Dataset</code></a> objects for each dataframe:</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">dataframe_to_dataset</span><span class="p">(</span><span class="n">dataframe</span><span class="p">):</span> <span class="n">dataframe</span> <span class="o">=</span> <span class="n">dataframe</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">dataframe</span><span class="o">.</span><span class="n">pop</span><span class="p">(</span><span class="s2">"target"</span><span class="p">)</span> <span class="n">ds</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">((</span><span class="nb">dict</span><span 
class="p">(</span><span class="n">dataframe</span><span class="p">),</span> <span class="n">labels</span><span class="p">))</span> <span class="n">ds</span> <span class="o">=</span> <span class="n">ds</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">buffer_size</span><span class="o">=</span><span class="nb">len</span><span class="p">(</span><span class="n">dataframe</span><span class="p">))</span> <span class="k">return</span> <span class="n">ds</span> <span class="n">train_ds</span> <span class="o">=</span> <span class="n">dataframe_to_dataset</span><span class="p">(</span><span class="n">train_dataframe</span><span class="p">)</span> <span class="n">val_ds</span> <span class="o">=</span> <span class="n">dataframe_to_dataset</span><span class="p">(</span><span class="n">val_dataframe</span><span class="p">)</span> </code></pre></div> <p>Each <code>Dataset</code> yields a tuple <code>(input, target)</code> where <code>input</code> is a dictionary of features and <code>target</code> is the value <code>0</code> or <code>1</code>:</p> <div class="codehilite"><pre><span></span><code><span class="k">for</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span> <span class="ow">in</span> <span class="n">train_ds</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="mi">1</span><span class="p">):</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Input:"</span><span class="p">,</span> <span class="n">x</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Target:"</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Input: {'age': <tf.Tensor: shape=(), dtype=int64, numpy=64>, 'sex': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'cp': <tf.Tensor: shape=(), 
dtype=int64, numpy=4>, 'trestbps': <tf.Tensor: shape=(), dtype=int64, numpy=128>, 'chol': <tf.Tensor: shape=(), dtype=int64, numpy=263>, 'fbs': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'restecg': <tf.Tensor: shape=(), dtype=int64, numpy=0>, 'thalach': <tf.Tensor: shape=(), dtype=int64, numpy=105>, 'exang': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'oldpeak': <tf.Tensor: shape=(), dtype=float64, numpy=0.2>, 'slope': <tf.Tensor: shape=(), dtype=int64, numpy=2>, 'ca': <tf.Tensor: shape=(), dtype=int64, numpy=1>, 'thal': <tf.Tensor: shape=(), dtype=string, numpy=b'reversible'>} Target: tf.Tensor(0, shape=(), dtype=int64) </code></pre></div> </div> <p>Let's batch the datasets:</p> <div class="codehilite"><pre><span></span><code><span class="n">train_ds</span> <span class="o">=</span> <span class="n">train_ds</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="mi">32</span><span class="p">)</span> <span class="n">val_ds</span> <span class="o">=</span> <span class="n">val_ds</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="mi">32</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="feature-preprocessing-with-keras-layers">Feature preprocessing with Keras layers</h2> <p>The following features are categorical features encoded as integers:</p> <ul> <li><code>sex</code></li> <li><code>cp</code></li> <li><code>fbs</code></li> <li><code>restecg</code></li> <li><code>exang</code></li> <li><code>ca</code></li> </ul> <p>We will encode these features using <strong>one-hot encoding</strong>. 
We have two options here:</p> <ul> <li>Use <code>CategoryEncoding()</code>, which requires knowing the range of input values in advance and raises an error on inputs outside that range.</li> <li>Use <code>IntegerLookup()</code>, which builds a lookup table (vocabulary) from the inputs and reserves an output index for unknown input values.</li> </ul> <p>For this example, we want a simple solution that handles out-of-range inputs at inference time, so we will use <code>IntegerLookup()</code>.</p> <p>We also have a categorical feature encoded as a string: <code>thal</code>. We will build an index of all its possible values and one-hot encode it using the <code>StringLookup()</code> layer.</p> <p>Finally, the following features are continuous numerical features:</p> <ul> <li><code>age</code></li> <li><code>trestbps</code></li> <li><code>chol</code></li> <li><code>thalach</code></li> <li><code>oldpeak</code></li> <li><code>slope</code></li> </ul> <p>For each of these features, we will use a <code>Normalization()</code> layer so that the feature has a mean of 0 and a standard deviation of 1, using statistics computed on the training data.</p> <p>Below, we define two utility functions to perform these operations:</p> <ul> <li><code>encode_numerical_feature</code> to apply feature-wise normalization to numerical features.</li> <li><code>encode_categorical_feature</code> to one-hot encode string or integer categorical features.</li> </ul> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">encode_numerical_feature</span><span class="p">(</span><span class="n">feature</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">dataset</span><span class="p">):</span> <span class="c1"># Create a Normalization layer for our feature</span> <span class="n">normalizer</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Normalization</span><span class="p">()</span> <span class="c1"># Prepare a Dataset that only yields our feature</span> <span
class="n">feature_ds</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">x</span><span class="p">[</span><span class="n">name</span><span class="p">])</span> <span class="n">feature_ds</span> <span class="o">=</span> <span class="n">feature_ds</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span> <span class="c1"># Learn the statistics of the data</span> <span class="n">normalizer</span><span class="o">.</span><span class="n">adapt</span><span class="p">(</span><span class="n">feature_ds</span><span class="p">)</span> <span class="c1"># Normalize the input feature</span> <span class="n">encoded_feature</span> <span class="o">=</span> <span class="n">normalizer</span><span class="p">(</span><span class="n">feature</span><span class="p">)</span> <span class="k">return</span> <span class="n">encoded_feature</span> <span class="k">def</span> <span class="nf">encode_categorical_feature</span><span class="p">(</span><span class="n">feature</span><span class="p">,</span> <span class="n">name</span><span class="p">,</span> <span class="n">dataset</span><span class="p">,</span> <span class="n">is_string</span><span class="p">):</span> <span class="n">lookup_class</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">StringLookup</span> <span class="k">if</span> <span class="n">is_string</span> <span class="k">else</span> <span class="n">layers</span><span 
class="o">.</span><span class="n">IntegerLookup</span> <span class="c1"># Create a lookup layer that maps each value to a vocabulary index and outputs a multi-hot ("binary") vector, which is one-hot when each sample has a single value</span> <span class="n">lookup</span> <span class="o">=</span> <span class="n">lookup_class</span><span class="p">(</span><span class="n">output_mode</span><span class="o">=</span><span class="s2">"binary"</span><span class="p">)</span> <span class="c1"># Prepare a Dataset that only yields our feature</span> <span class="n">feature_ds</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">:</span> <span class="n">x</span><span class="p">[</span><span class="n">name</span><span class="p">])</span> <span class="n">feature_ds</span> <span class="o">=</span> <span class="n">feature_ds</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="o">-</span><span class="mi">1</span><span class="p">))</span> <span class="c1"># Learn the vocabulary (the set of possible values) and assign each value a fixed index</span> <span class="n">lookup</span><span class="o">.</span><span class="n">adapt</span><span class="p">(</span><span class="n">feature_ds</span><span class="p">)</span> <span class="c1"># Encode the input feature as a one-hot vector</span> <span class="n">encoded_feature</span> <span class="o">=</span> <span class="n">lookup</span><span class="p">(</span><span class="n">feature</span><span class="p">)</span> <span class="k">return</span> <span class="n">encoded_feature</span> </code></pre></div> <hr /> <h2 id="build-a-model">Build a model</h2> <p>With this done, we can create
our end-to-end model:</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Categorical features encoded as integers</span> <span class="n">sex</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"sex"</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int64"</span><span class="p">)</span> <span class="n">cp</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"cp"</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int64"</span><span class="p">)</span> <span class="n">fbs</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"fbs"</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int64"</span><span class="p">)</span> <span class="n">restecg</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span 
class="s2">"restecg"</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int64"</span><span class="p">)</span> <span class="n">exang</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"exang"</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int64"</span><span class="p">)</span> <span class="n">ca</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"ca"</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"int64"</span><span class="p">)</span> <span class="c1"># Categorical feature encoded as string</span> <span class="n">thal</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"thal"</span><span class="p">,</span> <span class="n">dtype</span><span class="o">=</span><span class="s2">"string"</span><span class="p">)</span> <span class="c1"># Numerical features</span> <span class="n">age</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span 
class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"age"</span><span class="p">)</span> <span class="n">trestbps</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"trestbps"</span><span class="p">)</span> <span class="n">chol</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"chol"</span><span class="p">)</span> <span class="n">thalach</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"thalach"</span><span class="p">)</span> <span class="n">oldpeak</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">1</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"oldpeak"</span><span class="p">)</span> <span class="n">slope</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span 
class="mi">1</span><span class="p">,),</span> <span class="n">name</span><span class="o">=</span><span class="s2">"slope"</span><span class="p">)</span> <span class="n">all_inputs</span> <span class="o">=</span> <span class="p">[</span> <span class="n">sex</span><span class="p">,</span> <span class="n">cp</span><span class="p">,</span> <span class="n">fbs</span><span class="p">,</span> <span class="n">restecg</span><span class="p">,</span> <span class="n">exang</span><span class="p">,</span> <span class="n">ca</span><span class="p">,</span> <span class="n">thal</span><span class="p">,</span> <span class="n">age</span><span class="p">,</span> <span class="n">trestbps</span><span class="p">,</span> <span class="n">chol</span><span class="p">,</span> <span class="n">thalach</span><span class="p">,</span> <span class="n">oldpeak</span><span class="p">,</span> <span class="n">slope</span><span class="p">,</span> <span class="p">]</span> <span class="c1"># Integer categorical features</span> <span class="n">sex_encoded</span> <span class="o">=</span> <span class="n">encode_categorical_feature</span><span class="p">(</span><span class="n">sex</span><span class="p">,</span> <span class="s2">"sex"</span><span class="p">,</span> <span class="n">train_ds</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span> <span class="n">cp_encoded</span> <span class="o">=</span> <span class="n">encode_categorical_feature</span><span class="p">(</span><span class="n">cp</span><span class="p">,</span> <span class="s2">"cp"</span><span class="p">,</span> <span class="n">train_ds</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span> <span class="n">fbs_encoded</span> <span class="o">=</span> <span class="n">encode_categorical_feature</span><span class="p">(</span><span class="n">fbs</span><span class="p">,</span> <span class="s2">"fbs"</span><span class="p">,</span> <span class="n">train_ds</span><span class="p">,</span> <span 
class="kc">False</span><span class="p">)</span> <span class="n">restecg_encoded</span> <span class="o">=</span> <span class="n">encode_categorical_feature</span><span class="p">(</span><span class="n">restecg</span><span class="p">,</span> <span class="s2">"restecg"</span><span class="p">,</span> <span class="n">train_ds</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span> <span class="n">exang_encoded</span> <span class="o">=</span> <span class="n">encode_categorical_feature</span><span class="p">(</span><span class="n">exang</span><span class="p">,</span> <span class="s2">"exang"</span><span class="p">,</span> <span class="n">train_ds</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span> <span class="n">ca_encoded</span> <span class="o">=</span> <span class="n">encode_categorical_feature</span><span class="p">(</span><span class="n">ca</span><span class="p">,</span> <span class="s2">"ca"</span><span class="p">,</span> <span class="n">train_ds</span><span class="p">,</span> <span class="kc">False</span><span class="p">)</span> <span class="c1"># String categorical features</span> <span class="n">thal_encoded</span> <span class="o">=</span> <span class="n">encode_categorical_feature</span><span class="p">(</span><span class="n">thal</span><span class="p">,</span> <span class="s2">"thal"</span><span class="p">,</span> <span class="n">train_ds</span><span class="p">,</span> <span class="kc">True</span><span class="p">)</span> <span class="c1"># Numerical features</span> <span class="n">age_encoded</span> <span class="o">=</span> <span class="n">encode_numerical_feature</span><span class="p">(</span><span class="n">age</span><span class="p">,</span> <span class="s2">"age"</span><span class="p">,</span> <span class="n">train_ds</span><span class="p">)</span> <span class="n">trestbps_encoded</span> <span class="o">=</span> <span class="n">encode_numerical_feature</span><span class="p">(</span><span 
class="n">trestbps</span><span class="p">,</span> <span class="s2">"trestbps"</span><span class="p">,</span> <span class="n">train_ds</span><span class="p">)</span> <span class="n">chol_encoded</span> <span class="o">=</span> <span class="n">encode_numerical_feature</span><span class="p">(</span><span class="n">chol</span><span class="p">,</span> <span class="s2">"chol"</span><span class="p">,</span> <span class="n">train_ds</span><span class="p">)</span> <span class="n">thalach_encoded</span> <span class="o">=</span> <span class="n">encode_numerical_feature</span><span class="p">(</span><span class="n">thalach</span><span class="p">,</span> <span class="s2">"thalach"</span><span class="p">,</span> <span class="n">train_ds</span><span class="p">)</span> <span class="n">oldpeak_encoded</span> <span class="o">=</span> <span class="n">encode_numerical_feature</span><span class="p">(</span><span class="n">oldpeak</span><span class="p">,</span> <span class="s2">"oldpeak"</span><span class="p">,</span> <span class="n">train_ds</span><span class="p">)</span> <span class="n">slope_encoded</span> <span class="o">=</span> <span class="n">encode_numerical_feature</span><span class="p">(</span><span class="n">slope</span><span class="p">,</span> <span class="s2">"slope"</span><span class="p">,</span> <span class="n">train_ds</span><span class="p">)</span> <span class="n">all_features</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span> <span class="p">[</span> <span class="n">sex_encoded</span><span class="p">,</span> <span class="n">cp_encoded</span><span class="p">,</span> <span class="n">fbs_encoded</span><span class="p">,</span> <span class="n">restecg_encoded</span><span class="p">,</span> <span class="n">exang_encoded</span><span class="p">,</span> <span class="n">slope_encoded</span><span class="p">,</span> <span class="n">ca_encoded</span><span class="p">,</span> <span 
class="n">thal_encoded</span><span class="p">,</span> <span class="n">age_encoded</span><span class="p">,</span> <span class="n">trestbps_encoded</span><span class="p">,</span> <span class="n">chol_encoded</span><span class="p">,</span> <span class="n">thalach_encoded</span><span class="p">,</span> <span class="n">oldpeak_encoded</span><span class="p">,</span> <span class="p">]</span> <span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"relu"</span><span class="p">)(</span><span class="n">all_features</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">output</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"sigmoid"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">all_inputs</span><span class="p">,</span> <span class="n">output</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="s2">"adam"</span><span class="p">,</span> <span class="s2">"binary_crossentropy"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span 
class="p">[</span><span class="s2">"accuracy"</span><span class="p">])</span> </code></pre></div> <p>Let's visualize our connectivity graph (<code>plot_model</code> needs the <code>pydot</code> package and Graphviz to be installed):</p> <div class="codehilite"><pre><span></span><code><span class="c1"># `rankdir='LR'` lays the graph out horizontally (left to right).</span> <span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">plot_model</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">show_shapes</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">rankdir</span><span class="o">=</span><span class="s2">"LR"</span><span class="p">)</span> </code></pre></div> <p><img alt="Model connectivity graph" src="/img/examples/structured_data/structured_data_classification_from_scratch/structured_data_classification_from_scratch_23_0.png" /></p> <hr /> <h2 id="train-the-model">Train the model</h2> <div class="codehilite"><pre><span></span><code><span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">train_ds</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">50</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">val_ds</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 5s 46ms/step - accuracy: 0.3932 - loss: 0.8749 - val_accuracy: 0.3303 - val_loss: 0.7814 Epoch 2/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 1s 7ms/step - accuracy: 0.4262 - loss: 0.8375 - val_accuracy: 0.4914 - val_loss: 0.6980 Epoch 3/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.4835 - loss: 0.7350 - val_accuracy: 0.6541 - val_loss: 0.6320 Epoch 4/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.5932 - loss: 0.6665 - val_accuracy: 0.7543 - val_loss: 0.5743 Epoch 5/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 
7ms/step - accuracy: 0.5861 - loss: 0.6600 - val_accuracy: 0.7683 - val_loss: 0.5360 Epoch 6/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.6489 - loss: 0.6020 - val_accuracy: 0.7748 - val_loss: 0.4998 Epoch 7/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.6880 - loss: 0.5668 - val_accuracy: 0.7699 - val_loss: 0.4800 Epoch 8/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.7572 - loss: 0.5009 - val_accuracy: 0.7559 - val_loss: 0.4573 Epoch 9/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.7492 - loss: 0.5192 - val_accuracy: 0.8060 - val_loss: 0.4414 Epoch 10/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.7212 - loss: 0.4973 - val_accuracy: 0.8077 - val_loss: 0.4259 Epoch 11/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.7616 - loss: 0.4704 - val_accuracy: 0.7904 - val_loss: 0.4143 Epoch 12/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8374 - loss: 0.4342 - val_accuracy: 0.7872 - val_loss: 0.4061 Epoch 13/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.7863 - loss: 0.4630 - val_accuracy: 0.7888 - val_loss: 0.3980 Epoch 14/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.7742 - loss: 0.4492 - val_accuracy: 0.7996 - val_loss: 0.3998 Epoch 15/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8083 - loss: 0.4280 - val_accuracy: 0.8060 - val_loss: 0.3855 Epoch 16/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8058 - loss: 0.4191 - val_accuracy: 0.8217 - val_loss: 0.3819 Epoch 17/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8071 - loss: 0.4111 - val_accuracy: 0.8389 - val_loss: 0.3763 Epoch 18/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.8533 - loss: 0.3676 - val_accuracy: 0.8373 - val_loss: 0.3792 Epoch 19/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8170 - loss: 0.3850 - val_accuracy: 0.8357 - val_loss: 0.3744 Epoch 20/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8207 - loss: 0.3767 - val_accuracy: 0.8168 - val_loss: 0.3759 Epoch 21/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 
7ms/step - accuracy: 0.8151 - loss: 0.3596 - val_accuracy: 0.8217 - val_loss: 0.3685 Epoch 22/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.7988 - loss: 0.4087 - val_accuracy: 0.8184 - val_loss: 0.3701 Epoch 23/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8180 - loss: 0.3632 - val_accuracy: 0.8217 - val_loss: 0.3614 Epoch 24/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8295 - loss: 0.3504 - val_accuracy: 0.8200 - val_loss: 0.3683 Epoch 25/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8386 - loss: 0.3864 - val_accuracy: 0.8200 - val_loss: 0.3655 Epoch 26/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8482 - loss: 0.3345 - val_accuracy: 0.8044 - val_loss: 0.3639 Epoch 27/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.8340 - loss: 0.3470 - val_accuracy: 0.8077 - val_loss: 0.3616 Epoch 28/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8418 - loss: 0.3684 - val_accuracy: 0.8060 - val_loss: 0.3629 Epoch 29/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8309 - loss: 0.3147 - val_accuracy: 0.8060 - val_loss: 0.3637 Epoch 30/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8722 - loss: 0.3151 - val_accuracy: 0.8044 - val_loss: 0.3672 Epoch 31/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.8746 - loss: 0.3043 - val_accuracy: 0.8060 - val_loss: 0.3637 Epoch 32/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8794 - loss: 0.3245 - val_accuracy: 0.8200 - val_loss: 0.3685 Epoch 33/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.8644 - loss: 0.3541 - val_accuracy: 0.8357 - val_loss: 0.3714 Epoch 34/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8867 - loss: 0.3007 - val_accuracy: 0.8373 - val_loss: 0.3680 Epoch 35/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8737 - loss: 0.3168 - val_accuracy: 0.8357 - val_loss: 0.3695 Epoch 36/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8191 - loss: 0.3298 - val_accuracy: 0.8357 - val_loss: 0.3736 Epoch 37/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 
7ms/step - accuracy: 0.8613 - loss: 0.3543 - val_accuracy: 0.8357 - val_loss: 0.3745 Epoch 38/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8835 - loss: 0.2835 - val_accuracy: 0.8357 - val_loss: 0.3707 Epoch 39/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8784 - loss: 0.2893 - val_accuracy: 0.8357 - val_loss: 0.3716 Epoch 40/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8919 - loss: 0.2587 - val_accuracy: 0.8168 - val_loss: 0.3770 Epoch 41/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8882 - loss: 0.2660 - val_accuracy: 0.8217 - val_loss: 0.3674 Epoch 42/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8790 - loss: 0.2931 - val_accuracy: 0.8200 - val_loss: 0.3723 Epoch 43/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8851 - loss: 0.2892 - val_accuracy: 0.8200 - val_loss: 0.3733 Epoch 44/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8504 - loss: 0.3189 - val_accuracy: 0.8200 - val_loss: 0.3755 Epoch 45/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8610 - loss: 0.3116 - val_accuracy: 0.8184 - val_loss: 0.3788 Epoch 46/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.8956 - loss: 0.2544 - val_accuracy: 0.8184 - val_loss: 0.3738 Epoch 47/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.9080 - loss: 0.2895 - val_accuracy: 0.8217 - val_loss: 0.3750 Epoch 48/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8706 - loss: 0.2993 - val_accuracy: 0.8217 - val_loss: 0.3757 Epoch 49/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 7ms/step - accuracy: 0.8724 - loss: 0.2979 - val_accuracy: 0.8184 - val_loss: 0.3781 Epoch 50/50 8/8 ━━━━━━━━━━━━━━━━━━━━ 0s 8ms/step - accuracy: 0.8609 - loss: 0.2937 - val_accuracy: 0.8217 - val_loss: 0.3791 </code></pre></div> </div> <p>The model reaches about 80% validation accuracy within the first 10 epochs and finishes at roughly 82%.</p> <hr /> <h2 id="inference-on-new-data">Inference on new data</h2> <p>To get a prediction for a new sample, call <code>model.predict()</code>. 
Before doing so, you need to do two things:</p> <ol> <li>Wrap each scalar in a list so that it has a batch dimension (models only process batches of data, not single samples).</li> <li>Call <code>tf.convert_to_tensor</code> on each wrapped feature value.</li> </ol> <div class="codehilite"><pre><span></span><code><span class="n">sample</span> <span class="o">=</span> <span class="p">{</span> <span class="s2">"age"</span><span class="p">:</span> <span class="mi">60</span><span class="p">,</span> <span class="s2">"sex"</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">"cp"</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">"trestbps"</span><span class="p">:</span> <span class="mi">145</span><span class="p">,</span> <span class="s2">"chol"</span><span class="p">:</span> <span class="mi">233</span><span class="p">,</span> <span class="s2">"fbs"</span><span class="p">:</span> <span class="mi">1</span><span class="p">,</span> <span class="s2">"restecg"</span><span class="p">:</span> <span class="mi">2</span><span class="p">,</span> <span class="s2">"thalach"</span><span class="p">:</span> <span class="mi">150</span><span class="p">,</span> <span class="s2">"exang"</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">"oldpeak"</span><span class="p">:</span> <span class="mf">2.3</span><span class="p">,</span> <span class="s2">"slope"</span><span class="p">:</span> <span class="mi">3</span><span class="p">,</span> <span class="s2">"ca"</span><span class="p">:</span> <span class="mi">0</span><span class="p">,</span> <span class="s2">"thal"</span><span class="p">:</span> <span class="s2">"fixed"</span><span class="p">,</span> <span class="p">}</span> <span class="n">input_dict</span> <span class="o">=</span> <span class="p">{</span><span class="n">name</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span 
class="n">convert_to_tensor</span><span class="p">([</span><span class="n">value</span><span class="p">])</span> <span class="k">for</span> <span class="n">name</span><span class="p">,</span> <span class="n">value</span> <span class="ow">in</span> <span class="n">sample</span><span class="o">.</span><span class="n">items</span><span class="p">()}</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">input_dict</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span> <span class="sa">f</span><span class="s2">"This particular patient had a </span><span class="si">{</span><span class="mi">100</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="n">predictions</span><span class="p">[</span><span class="mi">0</span><span class="p">][</span><span class="mi">0</span><span class="p">]</span><span class="si">:</span><span class="s2">.1f</span><span class="si">}</span><span class="s2"> "</span> <span class="s2">"percent probability of having heart disease, "</span> <span class="s2">"as evaluated by our model."</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code> 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 252ms/step This particular patient had a 27.6 percent probability of having heart disease, as evaluated by our model. 
</code></pre></div> </div> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#structured-data-classification-from-scratch'>Structured data classification from scratch</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-3'> <a href='#the-dataset'>The dataset</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#preparing-the-data'>Preparing the data</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#feature-preprocessing-with-keras-layers'>Feature preprocessing with Keras layers</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#build-a-model'>Build a model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-the-model'>Train the model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#inference-on-new-data'>Inference on new data</a> </div> </div> </div> </div> </div> </body> </html>