How to train a Keras model on TFRecord files

<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/keras_recipes/tfrecord/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: How to train a Keras model on TFRecord files"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: How to train a Keras model on TFRecord files"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>How to train a Keras model on TFRecord files</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); 
ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink active" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-sublink2" href="/examples/keras_recipes/parameter_efficient_finetuning_of_gemma_with_lora_and_qlora/">Parameter-efficient fine-tuning of Gemma with LoRA and QLoRA</a> <a class="nav-sublink2" 
href="/examples/keras_recipes/float8_training_and_inference_with_transformer/">Float8 training and inference with a simple Transformer model</a> <a class="nav-sublink2" href="/examples/keras_recipes/tf_serving/">Serving TensorFlow models with TFServing</a> <a class="nav-sublink2" href="/examples/keras_recipes/debugging_tips/">Keras debugging tips</a> <a class="nav-sublink2" href="/examples/keras_recipes/subclassing_conv_layers/">Customizing the convolution operation of a Conv2D layer</a> <a class="nav-sublink2" href="/examples/keras_recipes/trainer_pattern/">Trainer pattern</a> <a class="nav-sublink2" href="/examples/keras_recipes/endpoint_layer_pattern/">Endpoint layer pattern</a> <a class="nav-sublink2" href="/examples/keras_recipes/reproducibility_recipes/">Reproducibility in Keras Models</a> <a class="nav-sublink2" href="/examples/keras_recipes/tensorflow_numpy_models/">Writing Keras Models With TensorFlow NumPy</a> <a class="nav-sublink2" href="/examples/keras_recipes/antirectifier/">Simple custom layer example: Antirectifier</a> <a class="nav-sublink2" href="/examples/keras_recipes/sample_size_estimate/">Estimating required sample size for model training</a> <a class="nav-sublink2" href="/examples/keras_recipes/memory_efficient_embeddings/">Memory-efficient embeddings for recommendation systems</a> <a class="nav-sublink2" href="/examples/keras_recipes/creating_tfrecords/">Creating TFRecords</a> <a class="nav-sublink2" href="/examples/keras_recipes/packaging_keras_models_for_wide_distribution/">Packaging Keras models for wide distribution using Functional Subclassing</a> <a class="nav-sublink2" href="/examples/keras_recipes/approximating_non_function_mappings/">Approximating non-Function Mappings with Mixture Density Networks</a> <a class="nav-sublink2" href="/examples/keras_recipes/bayesian_neural_networks/">Probabilistic Bayesian Neural Networks</a> <a class="nav-sublink2" href="/examples/keras_recipes/better_knowledge_distillation/">Knowledge distillation 
recipes</a> <a class="nav-sublink2" href="/examples/keras_recipes/sklearn_metric_callbacks/">Evaluating and exporting scikit-learn metrics in a Keras callback</a> <a class="nav-sublink2 active" href="/examples/keras_recipes/tfrecord/">How to train a Keras model on TFRecord files</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" 
class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/keras_recipes/'>Quick Keras Recipes</a> / How to train a Keras model on TFRecord files </div> <div class='k-content'> <h1 id="how-to-train-a-keras-model-on-tfrecord-files">How to train a Keras model on TFRecord files</h1> <p><strong>Author:</strong> Amy MiHyun Jang<br> <strong>Date created:</strong> 2020/07/29<br> <strong>Last modified:</strong> 2020/08/07<br></p> <div class='example_version_banner keras_2'>ⓘ This example uses Keras 2</div> <p><img class="k-inline-icon" 
src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_recipes/ipynb/tfrecord.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/keras_recipes/tfrecord.py"><strong>GitHub source</strong></a></p> <p><strong>Description:</strong> Loading TFRecords for computer vision models.</p> <hr /> <h2 id="introduction--set-up">Introduction + Set Up</h2> <p>TFRecords store a sequence of binary records, read linearly. They are a useful format for storing data because they can be read efficiently. Learn more about TFRecords <a href="https://www.tensorflow.org/tutorials/load_data/tfrecord">here</a>.</p> <p>We'll explore how we can easily load in TFRecords for our melanoma classifier.</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span> <span class="kn">from</span> <span class="nn">functools</span> <span class="kn">import</span> <span class="n">partial</span> <span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span> <span class="k">try</span><span class="p">:</span> <span class="n">tpu</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">distribute</span><span class="o">.</span><span class="n">cluster_resolver</span><span class="o">.</span><span class="n">TPUClusterResolver</span><span class="o">.</span><span class="n">connect</span><span class="p">()</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Device:&quot;</span><span class="p">,</span> <span class="n">tpu</span><span class="o">.</span><span class="n">master</span><span class="p">())</span> <span 
class="n">strategy</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">distribute</span><span class="o">.</span><span class="n">TPUStrategy</span><span class="p">(</span><span class="n">tpu</span><span class="p">)</span> <span class="k">except</span> <span class="ne">Exception</span><span class="p">:</span> <span class="n">strategy</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">distribute</span><span class="o">.</span><span class="n">get_strategy</span><span class="p">()</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Number of replicas:&quot;</span><span class="p">,</span> <span class="n">strategy</span><span class="o">.</span><span class="n">num_replicas_in_sync</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Number of replicas: 8 </code></pre></div> </div> <p>We want a larger batch size because our data is imbalanced.</p> <div class="codehilite"><pre><span></span><code><span class="n">AUTOTUNE</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span> <span class="n">GCS_PATH</span> <span class="o">=</span> <span class="s2">&quot;gs://kds-b38ce1b823c3ae623f5691483dbaa0f0363f04b0d6a90b63cf69946e&quot;</span> <span class="n">BATCH_SIZE</span> <span class="o">=</span> <span class="mi">64</span> <span class="n">IMAGE_SIZE</span> <span class="o">=</span> <span class="p">[</span><span class="mi">1024</span><span class="p">,</span> <span class="mi">1024</span><span class="p">]</span> </code></pre></div> <hr /> <h2 id="load-the-data">Load the data</h2> <div class="codehilite"><pre><span></span><code><span class="n">FILENAMES</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">gfile</span><span 
class="o">.</span><span class="n">glob</span><span class="p">(</span><span class="n">GCS_PATH</span> <span class="o">+</span> <span class="s2">&quot;/tfrecords/train*.tfrec&quot;</span><span class="p">)</span> <span class="n">split_ind</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="mf">0.9</span> <span class="o">*</span> <span class="nb">len</span><span class="p">(</span><span class="n">FILENAMES</span><span class="p">))</span> <span class="n">TRAINING_FILENAMES</span><span class="p">,</span> <span class="n">VALID_FILENAMES</span> <span class="o">=</span> <span class="n">FILENAMES</span><span class="p">[:</span><span class="n">split_ind</span><span class="p">],</span> <span class="n">FILENAMES</span><span class="p">[</span><span class="n">split_ind</span><span class="p">:]</span> <span class="n">TEST_FILENAMES</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">gfile</span><span class="o">.</span><span class="n">glob</span><span class="p">(</span><span class="n">GCS_PATH</span> <span class="o">+</span> <span class="s2">&quot;/tfrecords/test*.tfrec&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Train TFRecord Files:&quot;</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">TRAINING_FILENAMES</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Validation TFRecord Files:&quot;</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">VALID_FILENAMES</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Test TFRecord Files:&quot;</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">TEST_FILENAMES</span><span class="p">))</span> 
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Train TFRecord Files: 14 Validation TFRecord Files: 2 Test TFRecord Files: 16 </code></pre></div> </div> <h3 id="decoding-the-data">Decoding the data</h3> <p>The images have to be converted to tensors so that they are valid inputs to our model. As the images are RGB, we specify 3 channels.</p> <p>We also reshape our data so that all of the images have the same shape.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">decode_image</span><span class="p">(</span><span class="n">image</span><span class="p">):</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">decode_jpeg</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">channels</span><span class="o">=</span><span class="mi">3</span><span class="p">)</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="n">tf</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="n">image</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">image</span><span class="p">,</span> <span class="p">[</span><span class="o">*</span><span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="k">return</span> <span class="n">image</span> </code></pre></div> <p>As we load in our data, we need both our <code>X</code> and our <code>Y</code>. The X is our image; the model will find features and patterns in our image dataset. 
We want to predict Y, the probability that the lesion in the image is malignant. We will go through our TFRecords and parse out the image and the target values.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">read_tfrecord</span><span class="p">(</span><span class="n">example</span><span class="p">,</span> <span class="n">labeled</span><span class="p">):</span> <span class="n">tfrecord_format</span> <span class="o">=</span> <span class="p">(</span> <span class="p">{</span> <span class="s2">&quot;image&quot;</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">FixedLenFeature</span><span class="p">([],</span> <span class="n">tf</span><span class="o">.</span><span class="n">string</span><span class="p">),</span> <span class="s2">&quot;target&quot;</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">FixedLenFeature</span><span class="p">([],</span> <span class="n">tf</span><span class="o">.</span><span class="n">int64</span><span class="p">),</span> <span class="p">}</span> <span class="k">if</span> <span class="n">labeled</span> <span class="k">else</span> <span class="p">{</span><span class="s2">&quot;image&quot;</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">FixedLenFeature</span><span class="p">([],</span> <span class="n">tf</span><span class="o">.</span><span class="n">string</span><span class="p">),}</span> <span class="p">)</span> <span class="n">example</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">io</span><span class="o">.</span><span class="n">parse_single_example</span><span class="p">(</span><span class="n">example</span><span class="p">,</span> <span 
class="n">tfrecord_format</span><span class="p">)</span> <span class="n">image</span> <span class="o">=</span> <span class="n">decode_image</span><span class="p">(</span><span class="n">example</span><span class="p">[</span><span class="s2">&quot;image&quot;</span><span class="p">])</span> <span class="k">if</span> <span class="n">labeled</span><span class="p">:</span> <span class="n">label</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">example</span><span class="p">[</span><span class="s2">&quot;target&quot;</span><span class="p">],</span> <span class="n">tf</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span> <span class="k">return</span> <span class="n">image</span><span class="p">,</span> <span class="n">label</span> <span class="k">return</span> <span class="n">image</span> </code></pre></div> <h3 id="define-loading-methods">Define loading methods</h3> <p>Our dataset is not ordered in any meaningful way, so the order can be ignored when loading our dataset. 
Ignoring the order and reading files as soon as they stream in shortens the time it takes to load the data.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">load_dataset</span><span class="p">(</span><span class="n">filenames</span><span class="p">,</span> <span class="n">labeled</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span> <span class="n">ignore_order</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Options</span><span class="p">()</span> <span class="n">ignore_order</span><span class="o">.</span><span class="n">experimental_deterministic</span> <span class="o">=</span> <span class="kc">False</span> <span class="c1"># disable order, increase speed</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">TFRecordDataset</span><span class="p">(</span> <span class="n">filenames</span> <span class="p">)</span> <span class="c1"># automatically interleaves reads from multiple files</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">with_options</span><span class="p">(</span> <span class="n">ignore_order</span> <span class="p">)</span> <span class="c1"># uses data as soon as it streams in, rather than in its original order</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">map</span><span class="p">(</span> <span class="n">partial</span><span class="p">(</span><span class="n">read_tfrecord</span><span class="p">,</span> <span class="n">labeled</span><span class="o">=</span><span class="n">labeled</span><span class="p">),</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span 
class="n">AUTOTUNE</span> <span class="p">)</span> <span class="c1"># returns a dataset of (image, label) pairs if labeled=True or just images if labeled=False</span> <span class="k">return</span> <span class="n">dataset</span> </code></pre></div> <p>We define the following function to get our different datasets.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">get_dataset</span><span class="p">(</span><span class="n">filenames</span><span class="p">,</span> <span class="n">labeled</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">load_dataset</span><span class="p">(</span><span class="n">filenames</span><span class="p">,</span> <span class="n">labeled</span><span class="o">=</span><span class="n">labeled</span><span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="mi">2048</span><span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">buffer_size</span><span class="o">=</span><span class="n">AUTOTUNE</span><span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">BATCH_SIZE</span><span class="p">)</span> <span class="k">return</span> <span class="n">dataset</span> </code></pre></div> <h3 id="visualize-input-images">Visualize input images</h3> <div class="codehilite"><pre><span></span><code><span class="n">train_dataset</span> <span class="o">=</span> <span class="n">get_dataset</span><span class="p">(</span><span class="n">TRAINING_FILENAMES</span><span class="p">)</span> <span 
class="n">valid_dataset</span> <span class="o">=</span> <span class="n">get_dataset</span><span class="p">(</span><span class="n">VALID_FILENAMES</span><span class="p">)</span> <span class="n">test_dataset</span> <span class="o">=</span> <span class="n">get_dataset</span><span class="p">(</span><span class="n">TEST_FILENAMES</span><span class="p">,</span> <span class="n">labeled</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="n">image_batch</span><span class="p">,</span> <span class="n">label_batch</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">))</span> <span class="k">def</span> <span class="nf">show_batch</span><span class="p">(</span><span class="n">image_batch</span><span class="p">,</span> <span class="n">label_batch</span><span class="p">):</span> <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">25</span><span class="p">):</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="n">n</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image_batch</span><span class="p">[</span><span class="n">n</span><span class="p">]</span> <span 
class="o">/</span> <span class="mf">255.0</span><span class="p">)</span> <span class="k">if</span> <span class="n">label_batch</span><span class="p">[</span><span class="n">n</span><span class="p">]:</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">&quot;MALIGNANT&quot;</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">&quot;BENIGN&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">&quot;off&quot;</span><span class="p">)</span> <span class="n">show_batch</span><span class="p">(</span><span class="n">image_batch</span><span class="o">.</span><span class="n">numpy</span><span class="p">(),</span> <span class="n">label_batch</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span> </code></pre></div> <p><img alt="png" src="/img/examples/keras_recipes/tfrecord/tfrecord_16_0.png" /></p> <hr /> <h2 id="building-our-model">Building our model</h2> <h3 id="define-callbacks">Define callbacks</h3> <p>The following learning rate schedule lets the model lower its learning rate as training progresses.</p> <p>We can use callbacks to stop training when the model stops improving. 
At the end of the training process, the model will restore the weights of its best iteration.</p> <div class="codehilite"><pre><span></span><code><span class="n">initial_learning_rate</span> <span class="o">=</span> <span class="mf">0.01</span> <span class="n">lr_schedule</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">schedules</span><span class="o">.</span><span class="n">ExponentialDecay</span><span class="p">(</span> <span class="n">initial_learning_rate</span><span class="p">,</span> <span class="n">decay_steps</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">decay_rate</span><span class="o">=</span><span class="mf">0.96</span><span class="p">,</span> <span class="n">staircase</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="n">checkpoint_cb</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">ModelCheckpoint</span><span class="p">(</span> <span class="s2">&quot;melanoma_model.h5&quot;</span><span class="p">,</span> <span class="n">save_best_only</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> <span class="n">early_stopping_cb</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">EarlyStopping</span><span class="p">(</span> <span class="n">patience</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="n">restore_best_weights</span><span class="o">=</span><span class="kc">True</span> <span class="p">)</span> </code></pre></div> <h3 
id="build-our-base-model">Build our base model</h3> <p>Transfer learning is a great way to reap the benefits of a well-trained model without having to train the model ourselves. For this notebook, we want to import the Xception model. A more in-depth analysis of transfer learning can be found <a href="https://keras.io/examples/vision/image_classification_efficientnet_fine_tuning/">here</a>.</p> <p>We do not want our metric to be <code>accuracy</code> because our data is imbalanced. For our example, we will be looking at the area under the ROC curve (AUC).</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">make_model</span><span class="p">():</span> <span class="n">base_model</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">applications</span><span class="o">.</span><span class="n">Xception</span><span class="p">(</span> <span class="n">input_shape</span><span class="o">=</span><span class="p">(</span><span class="o">*</span><span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">include_top</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">weights</span><span class="o">=</span><span class="s2">&quot;imagenet&quot;</span> <span class="p">)</span> <span class="n">base_model</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">False</span> <span class="n">inputs</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">([</span><span class="o">*</span><span class="n">IMAGE_SIZE</span><span class="p">,</span> <span class="mi">3</span><span class="p">])</span> <span class="n">x</span> <span 
class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">applications</span><span class="o">.</span><span class="n">xception</span><span class="o">.</span><span class="n">preprocess_input</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">base_model</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">GlobalAveragePooling2D</span><span class="p">()(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">8</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="mf">0.7</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">outputs</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span 
class="n">activation</span><span class="o">=</span><span class="s2">&quot;sigmoid&quot;</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">outputs</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="n">lr_schedule</span><span class="p">),</span> <span class="n">loss</span><span class="o">=</span><span class="s2">&quot;binary_crossentropy&quot;</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">AUC</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">&quot;auc&quot;</span><span class="p">),</span> <span class="p">)</span> <span class="k">return</span> <span class="n">model</span> </code></pre></div> <hr /> <h2 id="train-the-model">Train the model</h2> <div class="codehilite"><pre><span></span><code><span class="k">with</span> <span class="n">strategy</span><span class="o">.</span><span class="n">scope</span><span class="p">():</span> <span class="n">model</span> <span class="o">=</span> <span 
class="n">make_model</span><span class="p">()</span> <span class="n">history</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">train_dataset</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">valid_dataset</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">checkpoint_cb</span><span class="p">,</span> <span class="n">early_stopping_cb</span><span class="p">],</span> <span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5 83689472/83683744 [==============================] - 3s 0us/step Epoch 1/2 454/454 [==============================] - 525s 1s/step - loss: 0.1895 - auc: 0.5841 - val_loss: 0.0825 - val_auc: 0.8109 Epoch 2/2 454/454 [==============================] - 118s 260ms/step - loss: 0.1063 - auc: 0.5994 - val_loss: 0.0861 - val_auc: 0.8336 </code></pre></div> </div> <hr /> <h2 id="predict-results">Predict results</h2> <p>We'll use our model to predict results for our test dataset images. 
Values closer to <code>0</code> are more likely to be benign and values closer to <code>1</code> are more likely to be malignant.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">show_batch_predictions</span><span class="p">(</span><span class="n">image_batch</span><span class="p">):</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span>
    <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">25</span><span class="p">):</span>
        <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="n">n</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
        <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">image_batch</span><span class="p">[</span><span class="n">n</span><span class="p">]</span> <span class="o">/</span> <span class="mf">255.0</span><span class="p">)</span>
        <span class="n">img_array</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">image_batch</span><span class="p">[</span><span class="n">n</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>
        <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">img_array</span><span class="p">)[</span><span class="mi">0</span><span class="p">])</span>
        <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">&quot;off&quot;</span><span class="p">)</span>


<span class="n">image_batch</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">test_dataset</span><span class="p">))</span>

<span class="n">show_batch_predictions</span><span class="p">(</span><span class="n">image_batch</span><span class="p">)</span>
</code></pre></div> <p><img alt="png" src="/img/examples/keras_recipes/tfrecord/tfrecord_25_1.png" /></p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#how-to-train-a-keras-model-on-tfrecord-files'>How to train a Keras model on TFRecord files</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction--set-up'>Introduction + Set Up</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#load-the-data'>Load the data</a> </div> <div class='k-outline-depth-3'> <a href='#decoding-the-data'>Decoding the data</a> </div> <div class='k-outline-depth-3'> <a href='#define-loading-methods'>Define loading methods</a> </div> <div class='k-outline-depth-3'> <a href='#visualize-input-images'>Visualize input images</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#building-our-model'>Building our model</a> </div> <div class='k-outline-depth-3'> <a href='#define-callbacks'>Define callbacks</a> </div> <div class='k-outline-depth-3'> <a href='#build-our-base-model'>Build our base model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-the-model'>Train the model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#predict-results'>Predict results</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px 
#bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>
