<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/keras_recipes/sample_size_estimate/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Estimating required sample size for model training"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Estimating required sample size for model training"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Estimating required sample size for model training</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 
'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink active" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-sublink2" href="/examples/keras_recipes/parameter_efficient_finetuning_of_gemma_with_lora_and_qlora/">Parameter-efficient fine-tuning of Gemma with LoRA and QLoRA</a> 
<a class="nav-sublink2" href="/examples/keras_recipes/float8_training_and_inference_with_transformer/">Float8 training and inference with a simple Transformer model</a> <a class="nav-sublink2" href="/examples/keras_recipes/tf_serving/">Serving TensorFlow models with TFServing</a> <a class="nav-sublink2" href="/examples/keras_recipes/debugging_tips/">Keras debugging tips</a> <a class="nav-sublink2" href="/examples/keras_recipes/subclassing_conv_layers/">Customizing the convolution operation of a Conv2D layer</a> <a class="nav-sublink2" href="/examples/keras_recipes/trainer_pattern/">Trainer pattern</a> <a class="nav-sublink2" href="/examples/keras_recipes/endpoint_layer_pattern/">Endpoint layer pattern</a> <a class="nav-sublink2" href="/examples/keras_recipes/reproducibility_recipes/">Reproducibility in Keras Models</a> <a class="nav-sublink2" href="/examples/keras_recipes/tensorflow_numpy_models/">Writing Keras Models With TensorFlow NumPy</a> <a class="nav-sublink2" href="/examples/keras_recipes/antirectifier/">Simple custom layer example: Antirectifier</a> <a class="nav-sublink2 active" href="/examples/keras_recipes/sample_size_estimate/">Estimating required sample size for model training</a> <a class="nav-sublink2" href="/examples/keras_recipes/memory_efficient_embeddings/">Memory-efficient embeddings for recommendation systems</a> <a class="nav-sublink2" href="/examples/keras_recipes/creating_tfrecords/">Creating TFRecords</a> <a class="nav-sublink2" href="/examples/keras_recipes/packaging_keras_models_for_wide_distribution/">Packaging Keras models for wide distribution using Functional Subclassing</a> <a class="nav-sublink2" href="/examples/keras_recipes/approximating_non_function_mappings/">Approximating non-Function Mappings with Mixture Density Networks</a> <a class="nav-sublink2" href="/examples/keras_recipes/bayesian_neural_networks/">Probabilistic Bayesian Neural Networks</a> <a class="nav-sublink2" 
href="/examples/keras_recipes/better_knowledge_distillation/">Knowledge distillation recipes</a> <a class="nav-sublink2" href="/examples/keras_recipes/sklearn_metric_callbacks/">Evaluating and exporting scikit-learn metrics in a Keras callback</a> <a class="nav-sublink2" href="/examples/keras_recipes/tfrecord/">How to train a Keras model on TFRecord files</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparameter Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> <a class="nav-link" href="/keras_cv/" role="tab" aria-selected="">KerasCV: Computer Vision Workflows</a> <a class="nav-link" href="/keras_nlp/" role="tab" aria-selected="">KerasNLP: Natural Language Workflows</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex 
align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + encodeURIComponent(query); return false; } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/keras_recipes/'>Quick Keras Recipes</a> / Estimating required sample size for model training </div> <div class='k-content'> <h1 id="estimating-required-sample-size-for-model-training">Estimating required sample size for model training</h1> <p><strong>Author:</strong> <a href="https://twitter.com/JacoVerster">JacoVerster</a><br> <strong>Date created:</strong> 2021/05/20<br> <strong>Last modified:</strong> 
2021/06/06<br> <strong>Description:</strong> Modeling the relationship between training set size and model accuracy.</p> <div class='example_version_banner keras_3'>ⓘ This example uses Keras 3</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_recipes/ipynb/sample_size_estimate.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team/keras-io/blob/master/examples/keras_recipes/sample_size_estimate.py"><strong>GitHub source</strong></a></p> <h1 id="introduction">Introduction</h1> <p>In many real-world scenarios, the amount of image data available to train a deep learning model is limited. This is especially true in the medical imaging domain, where dataset creation is costly. One of the first questions that usually comes up when approaching a new problem is: <strong>"how many images will we need to train a good enough machine learning model?"</strong></p> <p>In most cases, a small set of samples is available, and we can use it to model the relationship between training data size and model performance. Such a model can be used to estimate the number of images needed to achieve the required model performance.</p> <p>A systematic review of <a href="https://www.researchgate.net/publication/335779941_Sample-Size_Determination_Methodologies_for_Machine_Learning_in_Medical_Imaging_Research_A_Systematic_Review">Sample-Size Determination Methodologies</a> by Balki et al. provides examples of several sample-size determination methods. In this example, a balanced subsampling scheme is used to determine the optimal sample size for our model. This is done by selecting a random subsample consisting of Y images and training the model using the subsample. 
The model is then evaluated on an independent test set. This process is repeated N times for each subsample size, with replacement, so that a mean and confidence interval can be constructed for the observed performance.</p> <hr /> <h2 id="setup">Setup</h2> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">os</span>

<span class="n">os</span><span class="o">.</span><span class="n">environ</span><span class="p">[</span><span class="s2">"KERAS_BACKEND"</span><span class="p">]</span> <span class="o">=</span> <span class="s2">"tensorflow"</span>

<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
<span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span>
<span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span>
<span class="kn">import</span> <span class="nn">keras</span>
<span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span>
<span class="kn">import</span> <span class="nn">tensorflow_datasets</span> <span class="k">as</span> <span class="nn">tfds</span>

<span class="c1"># Define seed and fixed variables</span>
<span class="n">seed</span> <span class="o">=</span> <span class="mi">42</span>
<span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">set_random_seed</span><span class="p">(</span><span class="n">seed</span><span class="p">)</span>
<span class="n">AUTO</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span>
</code></pre></div> <hr /> <h2 id="load-tensorflow-dataset-and-convert-to-numpy-arrays">Load TensorFlow dataset and convert to NumPy arrays</h2> <p>We'll be using the <a 
href="https://www.tensorflow.org/datasets/catalog/tf_flowers">TF Flowers dataset</a>.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Specify dataset parameters</span>
<span class="n">dataset_name</span> <span class="o">=</span> <span class="s2">"tf_flowers"</span>
<span class="n">batch_size</span> <span class="o">=</span> <span class="mi">64</span>
<span class="n">image_size</span> <span class="o">=</span> <span class="p">(</span><span class="mi">224</span><span class="p">,</span> <span class="mi">224</span><span class="p">)</span>

<span class="c1"># Load data from tfds and split 10% off for a test set</span>
<span class="p">(</span><span class="n">train_data</span><span class="p">,</span> <span class="n">test_data</span><span class="p">),</span> <span class="n">ds_info</span> <span class="o">=</span> <span class="n">tfds</span><span class="o">.</span><span class="n">load</span><span class="p">(</span>
    <span class="n">dataset_name</span><span class="p">,</span>
    <span class="n">split</span><span class="o">=</span><span class="p">[</span><span class="s2">"train[:90%]"</span><span class="p">,</span> <span class="s2">"train[90%:]"</span><span class="p">],</span>
    <span class="n">shuffle_files</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    <span class="n">as_supervised</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    <span class="n">with_info</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
<span class="p">)</span>

<span class="c1"># Extract number of classes and list of class names</span>
<span class="n">num_classes</span> <span class="o">=</span> <span class="n">ds_info</span><span class="o">.</span><span class="n">features</span><span class="p">[</span><span class="s2">"label"</span><span class="p">]</span><span class="o">.</span><span class="n">num_classes</span>
<span class="n">class_names</span> <span class="o">=</span> <span class="n">ds_info</span><span class="o">.</span><span class="n">features</span><span class="p">[</span><span class="s2">"label"</span><span class="p">]</span><span class="o">.</span><span class="n">names</span>

<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Number of classes: </span><span class="si">{</span><span class="n">num_classes</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Class names: </span><span class="si">{</span><span class="n">class_names</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>


<span class="c1"># Convert datasets to NumPy arrays</span>
<span class="k">def</span> <span class="nf">dataset_to_array</span><span class="p">(</span><span class="n">dataset</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">):</span>
    <span class="n">images</span><span class="p">,</span> <span class="n">labels</span> <span class="o">=</span> <span class="p">[],</span> <span class="p">[]</span>
    <span class="k">for</span> <span class="n">img</span><span class="p">,</span> <span class="n">lab</span> <span class="ow">in</span> <span class="n">dataset</span><span class="o">.</span><span class="n">as_numpy_iterator</span><span class="p">():</span>
        <span class="n">images</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">image</span><span class="o">.</span><span class="n">resize</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">image_size</span><span class="p">)</span><span class="o">.</span><span class="n">numpy</span><span class="p">())</span>
        <span class="n">labels</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">one_hot</span><span class="p">(</span><span class="n">lab</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">images</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span>


<span class="n">img_train</span><span class="p">,</span> <span class="n">label_train</span> <span class="o">=</span> <span class="n">dataset_to_array</span><span class="p">(</span><span class="n">train_data</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span>
<span class="n">img_test</span><span class="p">,</span> <span class="n">label_test</span> <span class="o">=</span> <span class="n">dataset_to_array</span><span class="p">(</span><span class="n">test_data</span><span class="p">,</span> <span class="n">image_size</span><span class="p">,</span> <span class="n">num_classes</span><span class="p">)</span>

<span class="n">num_train_samples</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">img_train</span><span class="p">)</span>
<span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Number of training samples: </span><span class="si">{</span><span class="n">num_train_samples</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span>
</code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Number of classes: 5
Class names: ['dandelion', 'daisy', 'tulips', 'sunflowers', 'roses']
Number of training samples: 3303
</code></pre></div> </div> <hr /> <h2 
id="plot-a-few-examples-from-the-test-set">Plot a few examples from the test set</h2> <div class="codehilite"><pre><span></span><code><span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">12</span><span class="p">))</span>
<span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">30</span><span class="p">):</span>
    <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="n">n</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">img_test</span><span class="p">[</span><span class="n">n</span><span class="p">]</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"uint8"</span><span class="p">))</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">class_names</span><span class="p">)[</span><span class="n">label_test</span><span class="p">[</span><span class="n">n</span><span class="p">]</span> <span class="o">==</span> <span class="kc">True</span><span class="p">][</span><span class="mi">0</span><span class="p">])</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span>
</code></pre></div> <p><img alt="png" 
src="/img/examples/keras_recipes/sample_size_estimate/sample_size_estimate_7_0.png" /></p> <hr /> <h2 id="augmentation">Augmentation</h2> <p>Define image augmentation using Keras preprocessing layers and apply it to the training set.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Define image augmentation model</span>
<span class="n">image_augmentation</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span>
    <span class="p">[</span>
        <span class="n">layers</span><span class="o">.</span><span class="n">RandomFlip</span><span class="p">(</span><span class="n">mode</span><span class="o">=</span><span class="s2">"horizontal"</span><span class="p">),</span>
        <span class="n">layers</span><span class="o">.</span><span class="n">RandomRotation</span><span class="p">(</span><span class="n">factor</span><span class="o">=</span><span class="mf">0.1</span><span class="p">),</span>
        <span class="n">layers</span><span class="o">.</span><span class="n">RandomZoom</span><span class="p">(</span><span class="n">height_factor</span><span class="o">=</span><span class="p">(</span><span class="o">-</span><span class="mf">0.1</span><span class="p">,</span> <span class="o">-</span><span class="mi">0</span><span class="p">)),</span>
        <span class="n">layers</span><span class="o">.</span><span class="n">RandomContrast</span><span class="p">(</span><span class="n">factor</span><span class="o">=</span><span class="mf">0.1</span><span class="p">),</span>
    <span class="p">],</span>
<span class="p">)</span>

<span class="c1"># Apply the augmentations to the training images and plot a few examples</span>
<span class="n">img_train</span> <span class="o">=</span> <span class="n">image_augmentation</span><span class="p">(</span><span class="n">img_train</span><span class="p">)</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>

<span class="n">plt</span><span class="o">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="mi">12</span><span class="p">))</span>
<span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">30</span><span class="p">):</span>
    <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplot</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="mi">6</span><span class="p">,</span> <span class="n">n</span> <span class="o">+</span> <span class="mi">1</span><span class="p">)</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">img_train</span><span class="p">[</span><span class="n">n</span><span class="p">]</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="s2">"uint8"</span><span class="p">))</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">class_names</span><span class="p">)[</span><span class="n">label_train</span><span class="p">[</span><span class="n">n</span><span class="p">]</span> <span class="o">==</span> <span class="kc">True</span><span class="p">][</span><span class="mi">0</span><span class="p">])</span>
    <span class="n">plt</span><span class="o">.</span><span class="n">axis</span><span class="p">(</span><span class="s2">"off"</span><span class="p">)</span>
</code></pre></div> <p><img alt="png" src="/img/examples/keras_recipes/sample_size_estimate/sample_size_estimate_9_0.png" /></p> <hr /> <h2 id="define-model-building-amp-training-functions">Define model building &amp; 
training functions</h2> <p>We create a few convenience functions to build a transfer-learning model, compile and train it, and unfreeze layers for fine-tuning.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">build_model</span><span class="p">(</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">img_size</span><span class="o">=</span><span class="n">image_size</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">top_dropout</span><span class="o">=</span><span class="mf">0.3</span><span class="p">):</span>
<span class="w">    </span><span class="sd">"""Creates a classifier based on pre-trained MobileNetV2.</span>

<span class="sd">    Arguments:</span>
<span class="sd">        num_classes: Int, number of classes to use in the softmax layer.</span>
<span class="sd">        img_size: Int, square size of input images (default is 224).</span>
<span class="sd">        top_dropout: Float, rate for the dropout layer (default is 0.3).</span>

<span class="sd">    Returns:</span>
<span class="sd">        Uncompiled Keras model.</span>
<span class="sd">    """</span>

    <span class="c1"># Create input and pre-processing layers for MobileNetV2</span>
    <span class="n">inputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Input</span><span class="p">(</span><span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">img_size</span><span class="p">,</span> <span class="n">img_size</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Rescaling</span><span class="p">(</span><span class="n">scale</span><span class="o">=</span><span class="mf">1.0</span> <span class="o">/</span> <span class="mf">127.5</span><span class="p">,</span> <span class="n">offset</span><span class="o">=-</span><span class="mi">1</span><span class="p">)(</span><span class="n">inputs</span><span class="p">)</span>
    <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">applications</span><span class="o">.</span><span class="n">MobileNetV2</span><span class="p">(</span>
        <span class="n">include_top</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">weights</span><span class="o">=</span><span class="s2">"imagenet"</span><span class="p">,</span> <span class="n">input_tensor</span><span class="o">=</span><span class="n">x</span>
    <span class="p">)</span>

    <span class="c1"># Freeze the pretrained weights</span>
    <span class="n">model</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">False</span>

    <span class="c1"># Rebuild top</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">GlobalAveragePooling2D</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"avg_pool"</span><span class="p">)(</span><span class="n">model</span><span class="o">.</span><span class="n">output</span><span class="p">)</span>
    <span class="n">x</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">top_dropout</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">outputs</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">num_classes</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">"softmax"</span><span class="p">)(</span><span class="n">x</span><span class="p">)</span>
    <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="p">)</span>

    <span class="nb">print</span><span class="p">(</span><span class="s2">"Trainable weights:"</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">))</span>
    <span class="nb">print</span><span class="p">(</span><span class="s2">"Non_trainable weights:"</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">non_trainable_weights</span><span class="p">))</span>
    <span class="k">return</span> <span class="n">model</span>


<span class="k">def</span> <span class="nf">compile_and_train</span><span class="p">(</span>
    <span class="n">model</span><span class="p">,</span>
    <span class="n">training_data</span><span class="p">,</span>
    <span class="n">training_labels</span><span class="p">,</span>
    <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">AUC</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"auc"</span><span class="p">),</span> <span class="s2">"acc"</span><span class="p">],</span>
    <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(),</span>
    <span class="n">patience</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span>
    <span class="n">epochs</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span>
<span class="p">):</span>
<span class="w">    </span><span class="sd">"""Compiles and trains the model.</span>

<span class="sd">    Arguments:</span>
<span class="sd">        model: Uncompiled Keras model.</span>
<span class="sd">        training_data: NumPy Array, training data.</span>
<span class="sd">        training_labels: NumPy Array, training labels.</span>
<span class="sd">        metrics: Keras/TF metrics, requires at least 'auc' metric (default is</span>
<span class="sd">            `[keras.metrics.AUC(name='auc'), 'acc']`).</span>
<span class="sd">        optimizer: Keras/TF optimizer (default is `keras.optimizers.Adam()`).</span>
<span class="sd">        patience: Int, number of epochs for EarlyStopping patience (default is 5).</span>
<span class="sd">        epochs: Int, number of epochs to train (default is 5).</span>

<span class="sd">    Returns:</span>
<span class="sd">        Training history for trained Keras model.</span>
<span class="sd">    """</span>

    <span class="n">stopper</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">EarlyStopping</span><span class="p">(</span>
        <span class="n">monitor</span><span class="o">=</span><span class="s2">"val_auc"</span><span class="p">,</span>
        <span class="n">mode</span><span class="o">=</span><span class="s2">"max"</span><span class="p">,</span>
        <span class="n">min_delta</span><span class="o">=</span><span class="mi">0</span><span class="p">,</span>
        <span class="n">patience</span><span class="o">=</span><span class="n">patience</span><span class="p">,</span>
        <span class="n">verbose</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span>
        <span class="n">restore_best_weights</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
    <span class="p">)</span>

    <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">loss</span><span class="o">=</span><span class="s2">"categorical_crossentropy"</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="n">metrics</span><span class="p">)</span>

    <span class="n">history</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span>
        <span class="n">x</span><span class="o">=</span><span class="n">training_data</span><span class="p">,</span>
        <span class="n">y</span><span class="o">=</span><span class="n">training_labels</span><span class="p">,</span>
        <span class="n">batch_size</span><span class="o">=</span><span class="n">batch_size</span><span class="p">,</span>
        <span class="n">epochs</span><span class="o">=</span><span class="n">epochs</span><span class="p">,</span>
        <span class="n">validation_split</span><span class="o">=</span><span class="mf">0.1</span><span class="p">,</span>
        <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">stopper</span><span class="p">],</span>
    <span class="p">)</span>
    <span class="k">return</span> <span class="n">history</span>


<span class="k">def</span> <span class="nf">unfreeze</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">block_name</span><span class="p">,</span> <span class="n">verbose</span><span class="o">=</span><span class="mi">0</span><span class="p">):</span>
<span class="w">    </span><span class="sd">"""Unfreezes Keras model layers.</span>

<span class="sd">    Arguments:</span>
<span class="sd">        model: Keras model.</span>
<span class="sd">        block_name: Str, layer name for example block_name = 'block4'.</span>
<span class="sd">                    Checks if supplied string is in the layer name.</span>
<span class="sd">        verbose: Int, 0 means silent, 1 prints out layers trainability status.</span>

<span class="sd">    Returns:</span>
<span class="sd">        Keras model with all layers after (and including) 
the specified</span> <span class="sd"> block_name set to trainable, excluding BatchNormalization layers.</span> <span class="sd"> """</span> <span class="c1"># Unfreeze from block_name onwards</span> <span class="n">set_trainable</span> <span class="o">=</span> <span class="kc">False</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">layers</span><span class="p">:</span> <span class="k">if</span> <span class="n">block_name</span> <span class="ow">in</span> <span class="n">layer</span><span class="o">.</span><span class="n">name</span><span class="p">:</span> <span class="n">set_trainable</span> <span class="o">=</span> <span class="kc">True</span> <span class="k">if</span> <span class="n">set_trainable</span> <span class="ow">and</span> <span class="ow">not</span> <span class="nb">isinstance</span><span class="p">(</span><span class="n">layer</span><span class="p">,</span> <span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">):</span> <span class="n">layer</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">True</span> <span class="k">if</span> <span class="n">verbose</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="s2">"trainable"</span><span class="p">)</span> <span class="k">else</span><span class="p">:</span> <span class="k">if</span> <span class="n">verbose</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">name</span><span class="p">,</span> <span class="s2">"NOT trainable"</span><span 
class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Trainable weights:"</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">))</span> <span class="nb">print</span><span class="p">(</span><span class="s2">"Non-trainable weights:"</span><span class="p">,</span> <span class="nb">len</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">non_trainable_weights</span><span class="p">))</span> <span class="k">return</span> <span class="n">model</span> </code></pre></div> <hr /> <h2 id="define-iterative-training-function">Define iterative training function</h2> <p>To train a model over several subsample sets we need to create an iterative training function.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">train_model</span><span class="p">(</span><span class="n">training_data</span><span class="p">,</span> <span class="n">training_labels</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Trains the model as follows:</span> <span class="sd"> - Trains only the top layers for 10 epochs.</span> <span class="sd"> - Unfreezes deeper layers.</span> <span class="sd"> - Train for 20 more epochs.</span> <span class="sd"> Arguments:</span> <span class="sd"> training_data: NumPy Array, training data.</span> <span class="sd"> training_labels: NumPy Array, training labels.</span> <span class="sd"> Returns:</span> <span class="sd"> Model accuracy.</span> <span class="sd"> """</span> <span class="n">model</span> <span class="o">=</span> <span class="n">build_model</span><span class="p">(</span><span class="n">num_classes</span><span class="p">)</span> <span class="c1"># Compile and train top layers</span> <span class="n">history</span> <span class="o">=</span> <span 
class="n">compile_and_train</span><span class="p">(</span> <span class="n">model</span><span class="p">,</span> <span class="n">training_data</span><span class="p">,</span> <span class="n">training_labels</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">AUC</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"auc"</span><span class="p">),</span> <span class="s2">"acc"</span><span class="p">],</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(),</span> <span class="n">patience</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Unfreeze model from block 10 onwards</span> <span class="n">model</span> <span class="o">=</span> <span class="n">unfreeze</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="s2">"block_10"</span><span class="p">)</span> <span class="c1"># Compile and train for 20 epochs with a lower learning rate</span> <span class="n">fine_tune_epochs</span> <span class="o">=</span> <span class="mi">20</span> <span class="n">total_epochs</span> <span class="o">=</span> <span class="n">history</span><span class="o">.</span><span class="n">epoch</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">fine_tune_epochs</span> <span class="n">history_fine</span> <span class="o">=</span> <span class="n">compile_and_train</span><span class="p">(</span> <span 
class="n">model</span><span class="p">,</span> <span class="n">training_data</span><span class="p">,</span> <span class="n">training_labels</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="n">keras</span><span class="o">.</span><span class="n">metrics</span><span class="o">.</span><span class="n">AUC</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s2">"auc"</span><span class="p">),</span> <span class="s2">"acc"</span><span class="p">],</span> <span class="n">optimizer</span><span class="o">=</span><span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">Adam</span><span class="p">(</span><span class="n">learning_rate</span><span class="o">=</span><span class="mf">1e-4</span><span class="p">),</span> <span class="n">patience</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">total_epochs</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Calculate model accuracy on the test set</span> <span class="n">_</span><span class="p">,</span> <span class="n">_</span><span class="p">,</span> <span class="n">acc</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">img_test</span><span class="p">,</span> <span class="n">label_test</span><span class="p">)</span> <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">round</span><span class="p">(</span><span class="n">acc</span><span class="p">,</span> <span class="mi">4</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="train-models-iteratively">Train models iteratively</h2> <p>Now that we have model building functions and supporting iterative functions we can train the 
model over several subsample splits.</p> <ul> <li>We select the subsample splits as 5%, 10%, 25% and 50% of the downloaded dataset. We pretend that only 50% of the actual data is available at present.</li> <li>We train the model 5 times from scratch at each split and record the accuracy values.</li> </ul> <p>Note that this trains 20 models and will take some time. Make sure you have a GPU runtime active.</p> <p>To keep this example lightweight, sample data from a previous training run is provided.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">train_iteratively</span><span class="p">(</span><span class="n">sample_splits</span><span class="o">=</span><span class="p">[</span><span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.25</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">],</span> <span class="n">iter_per_split</span><span class="o">=</span><span class="mi">5</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Trains a model iteratively over several sample splits.</span> <span class="sd"> Arguments:</span> <span class="sd"> sample_splits: List/NumPy array, contains fractions of the training set</span> <span class="sd"> to train over.</span> <span class="sd"> iter_per_split: Int, number of times to train a model per sample split.</span> <span class="sd"> Returns:</span> <span class="sd"> Test-set accuracy for all splits and iterations and the number of samples</span> <span class="sd"> used for training at each split.</span> <span class="sd"> """</span> <span class="c1"># Train all the sample models and calculate accuracy</span> <span class="n">train_acc</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">sample_sizes</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">fraction</span> <span class="ow">in</span> <span 
class="n">sample_splits</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Fraction split: </span><span class="si">{</span><span class="n">fraction</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="c1"># Repeat training iter_per_split times for each sample size</span> <span class="n">sample_accuracy</span> <span class="o">=</span> <span class="p">[]</span> <span class="n">num_samples</span> <span class="o">=</span> <span class="nb">int</span><span class="p">(</span><span class="n">num_train_samples</span> <span class="o">*</span> <span class="n">fraction</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">iter_per_split</span><span class="p">):</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Run </span><span class="si">{</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="si">}</span><span class="s2"> out of </span><span class="si">{</span><span class="n">iter_per_split</span><span class="si">}</span><span class="s2">:"</span><span class="p">)</span> <span class="c1"># Create fractional subsets (indices are drawn with replacement)</span> <span class="n">rand_idx</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="n">num_train_samples</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">num_samples</span><span class="p">)</span> <span class="n">train_img_subset</span> <span class="o">=</span> <span class="n">img_train</span><span class="p">[</span><span class="n">rand_idx</span><span class="p">,</span> <span class="p">:]</span> <span class="n">train_label_subset</span> <span 
class="o">=</span> <span class="n">label_train</span><span class="p">[</span><span class="n">rand_idx</span><span class="p">,</span> <span class="p">:]</span> <span class="c1"># Train model and calculate accuracy</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">train_model</span><span class="p">(</span><span class="n">train_img_subset</span><span class="p">,</span> <span class="n">train_label_subset</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Accuracy: </span><span class="si">{</span><span class="n">accuracy</span><span class="si">}</span><span class="s2">"</span><span class="p">)</span> <span class="n">sample_accuracy</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">accuracy</span><span class="p">)</span> <span class="n">train_acc</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">sample_accuracy</span><span class="p">)</span> <span class="n">sample_sizes</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">num_samples</span><span class="p">)</span> <span class="k">return</span> <span class="n">train_acc</span><span class="p">,</span> <span class="n">sample_sizes</span> <span class="c1"># Running the above function produces the following outputs</span> <span class="n">train_acc</span> <span class="o">=</span> <span class="p">[</span> <span class="p">[</span><span class="mf">0.8202</span><span class="p">,</span> <span class="mf">0.7466</span><span class="p">,</span> <span class="mf">0.8011</span><span class="p">,</span> <span class="mf">0.8447</span><span class="p">,</span> <span class="mf">0.8229</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.861</span><span class="p">,</span> <span class="mf">0.8774</span><span class="p">,</span> <span class="mf">0.8501</span><span 
class="p">,</span> <span class="mf">0.8937</span><span class="p">,</span> <span class="mf">0.891</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.891</span><span class="p">,</span> <span class="mf">0.9237</span><span class="p">,</span> <span class="mf">0.8856</span><span class="p">,</span> <span class="mf">0.9101</span><span class="p">,</span> <span class="mf">0.891</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.8937</span><span class="p">,</span> <span class="mf">0.9373</span><span class="p">,</span> <span class="mf">0.9128</span><span class="p">,</span> <span class="mf">0.8719</span><span class="p">,</span> <span class="mf">0.9128</span><span class="p">],</span> <span class="p">]</span> <span class="n">sample_sizes</span> <span class="o">=</span> <span class="p">[</span><span class="mi">165</span><span class="p">,</span> <span class="mi">330</span><span class="p">,</span> <span class="mi">825</span><span class="p">,</span> <span class="mi">1651</span><span class="p">]</span> </code></pre></div> <hr /> <h2 id="learning-curve">Learning curve</h2> <p>We now plot the learning curve by fitting an exponential curve through the mean accuracy points. 
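</p> <p>Because the fitted function in this example has the form <code>a * x**b</code>, it becomes a straight line in log-log space. The following is a minimal sketch of a closed-form cross-check using NumPy only; it is not the fitting method this example uses, and the helper name <code>fit_power_law</code> is hypothetical. It minimizes error in log space, so its weights can be expected to differ slightly from a fit that minimizes mean squared error on the accuracy values directly.</p>

```python
import numpy as np

# Accuracy values and sample sizes copied from the recorded run above.
train_acc = [
    [0.8202, 0.7466, 0.8011, 0.8447, 0.8229],
    [0.861, 0.8774, 0.8501, 0.8937, 0.891],
    [0.891, 0.9237, 0.8856, 0.9101, 0.891],
    [0.8937, 0.9373, 0.9128, 0.8719, 0.9128],
]
sample_sizes = [165, 330, 825, 1651]


def fit_power_law(sizes, accuracies):
    """Hypothetical helper: fits acc = a * n**b by least squares in log space."""
    log_n = np.log(np.asarray(sizes, dtype=float))
    log_acc = np.log(np.mean(accuracies, axis=1))
    # A degree-1 polynomial fit returns (slope, intercept) = (b, log(a)).
    b, log_a = np.polyfit(log_n, log_acc, deg=1)
    return float(np.exp(log_a)), float(b)


a, b = fit_power_law(sample_sizes, train_acc)
predicted = a * 3303**b
print(f"a = {a:.4f}, b = {b:.4f}, predicted accuracy at 3303 samples = {predicted:.4f}")
```

<p>A closed-form fit like this needs no learning rate or epoch count, which makes it a convenient sanity check on an iterative fit.</p> <p>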
We use TF to fit a function of the form <code>a * x**b</code> through the data (strictly a power law, although the code and plot label it "exponential").</p> <p>We then extrapolate the learning curve to predict the accuracy of a model trained on the whole training set.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">fit_and_predict</span><span class="p">(</span><span class="n">train_acc</span><span class="p">,</span> <span class="n">sample_sizes</span><span class="p">,</span> <span class="n">pred_sample_size</span><span class="p">):</span> <span class="w"> </span><span class="sd">"""Fits a learning curve to model training accuracy results.</span> <span class="sd"> Arguments:</span> <span class="sd"> train_acc: List/NumPy array, training accuracy for all model</span> <span class="sd"> training splits and iterations.</span> <span class="sd"> sample_sizes: List/NumPy array, number of samples used for training at</span> <span class="sd"> each split.</span> <span class="sd"> pred_sample_size: Int, sample size to predict model accuracy based on</span> <span class="sd"> fitted learning curve.</span> <span class="sd"> """</span> <span class="n">x</span> <span class="o">=</span> <span class="n">sample_sizes</span> <span class="n">mean_acc</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">convert_to_tensor</span><span class="p">([</span><span class="n">np</span><span class="o">.</span><span class="n">mean</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">train_acc</span><span class="p">])</span> <span class="n">error</span> <span class="o">=</span> <span class="p">[</span><span class="n">np</span><span class="o">.</span><span class="n">std</span><span class="p">(</span><span class="n">i</span><span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="n">train_acc</span><span 
class="p">]</span> <span class="c1"># Define mean squared error cost and exponential curve fit functions</span> <span class="n">mse</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">MeanSquaredError</span><span class="p">()</span> <span class="k">def</span> <span class="nf">exp_func</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">):</span> <span class="k">return</span> <span class="n">a</span> <span class="o">*</span> <span class="n">x</span><span class="o">**</span><span class="n">b</span> <span class="c1"># Define variables, learning rate and number of epochs for fitting with TF</span> <span class="n">a</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="mf">0.0</span><span class="p">)</span> <span class="n">b</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">Variable</span><span class="p">(</span><span class="mf">0.0</span><span class="p">)</span> <span class="n">learning_rate</span> <span class="o">=</span> <span class="mf">0.01</span> <span class="n">training_epochs</span> <span class="o">=</span> <span class="mi">5000</span> <span class="c1"># Fit the exponential function to the data</span> <span class="k">for</span> <span class="n">epoch</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">training_epochs</span><span class="p">):</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="n">y_pred</span> <span class="o">=</span> <span class="n">exp_func</span><span 
class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span> <span class="n">cost_function</span> <span class="o">=</span> <span class="n">mse</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">mean_acc</span><span class="p">)</span> <span class="c1"># Get gradients and compute adjusted weights</span> <span class="n">gradients</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">cost_function</span><span class="p">,</span> <span class="p">[</span><span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">])</span> <span class="n">a</span><span class="o">.</span><span class="n">assign_sub</span><span class="p">(</span><span class="n">gradients</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">learning_rate</span><span class="p">)</span> <span class="n">b</span><span class="o">.</span><span class="n">assign_sub</span><span class="p">(</span><span class="n">gradients</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">learning_rate</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"Curve fit weights: a = </span><span class="si">{</span><span class="n">a</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="si">}</span><span class="s2"> and b = </span><span class="si">{</span><span class="n">b</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="si">}</span><span class="s2">."</span><span class="p">)</span> <span class="c1"># We can now estimate the accuracy for pred_sample_size</span> <span 
class="n">max_acc</span> <span class="o">=</span> <span class="n">exp_func</span><span class="p">(</span><span class="n">pred_sample_size</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">)</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="c1"># Print predicted x value and append to plot values</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"A model accuracy of </span><span class="si">{</span><span class="n">max_acc</span><span class="si">}</span><span class="s2"> is predicted for </span><span class="si">{</span><span class="n">pred_sample_size</span><span class="si">}</span><span class="s2"> samples."</span><span class="p">)</span> <span class="n">x_cont</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">linspace</span><span class="p">(</span><span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">pred_sample_size</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span> <span class="c1"># Build the plot</span> <span class="n">fig</span><span class="p">,</span> <span class="n">ax</span> <span class="o">=</span> <span class="n">plt</span><span class="o">.</span><span class="n">subplots</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">12</span><span class="p">,</span> <span class="mi">6</span><span class="p">))</span> <span class="n">ax</span><span class="o">.</span><span class="n">errorbar</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">mean_acc</span><span class="p">,</span> <span class="n">yerr</span><span class="o">=</span><span class="n">error</span><span class="p">,</span> <span class="n">fmt</span><span class="o">=</span><span 
class="s2">"o"</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s2">"Mean acc & std dev."</span><span class="p">)</span> <span class="n">ax</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">x_cont</span><span class="p">,</span> <span class="n">exp_func</span><span class="p">(</span><span class="n">x_cont</span><span class="p">,</span> <span class="n">a</span><span class="p">,</span> <span class="n">b</span><span class="p">),</span> <span class="s2">"r-"</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s2">"Fitted exponential curve."</span><span class="p">)</span> <span class="n">ax</span><span class="o">.</span><span class="n">set_ylabel</span><span class="p">(</span><span class="s2">"Model classification accuracy."</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">12</span><span class="p">)</span> <span class="n">ax</span><span class="o">.</span><span class="n">set_xlabel</span><span class="p">(</span><span class="s2">"Training sample size."</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">12</span><span class="p">)</span> <span class="n">ax</span><span class="o">.</span><span class="n">set_xticks</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">pred_sample_size</span><span class="p">))</span> <span class="n">ax</span><span class="o">.</span><span class="n">set_yticks</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">mean_acc</span><span class="p">,</span> <span class="n">max_acc</span><span class="p">))</span> <span class="n">ax</span><span class="o">.</span><span 
class="n">set_xticklabels</span><span class="p">(</span><span class="nb">list</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">pred_sample_size</span><span class="p">)),</span> <span class="n">rotation</span><span class="o">=</span><span class="mi">90</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> <span class="n">ax</span><span class="o">.</span><span class="n">yaxis</span><span class="o">.</span><span class="n">set_tick_params</span><span class="p">(</span><span class="n">labelsize</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> <span class="n">ax</span><span class="o">.</span><span class="n">set_title</span><span class="p">(</span><span class="s2">"Learning curve: model accuracy vs sample size."</span><span class="p">,</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span> <span class="n">ax</span><span class="o">.</span><span class="n">legend</span><span class="p">(</span><span class="n">loc</span><span class="o">=</span><span class="p">(</span><span class="mf">0.75</span><span class="p">,</span> <span class="mf">0.75</span><span class="p">),</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span> <span class="n">ax</span><span class="o">.</span><span class="n">xaxis</span><span class="o">.</span><span class="n">grid</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span> <span class="n">ax</span><span class="o">.</span><span class="n">yaxis</span><span class="o">.</span><span class="n">grid</span><span class="p">(</span><span class="kc">True</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span 
class="n">tight_layout</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="c1"># The mean absolute error (MAE) is calculated for curve fit to see how well</span> <span class="c1"># it fits the data. The lower the error the better the fit.</span> <span class="n">mae</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">losses</span><span class="o">.</span><span class="n">MeanAbsoluteError</span><span class="p">()</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"The mae for the curve fit is </span><span class="si">{</span><span class="n">mae</span><span class="p">(</span><span class="n">mean_acc</span><span class="p">,</span><span class="w"> </span><span class="n">exp_func</span><span class="p">(</span><span class="n">x</span><span class="p">,</span><span class="w"> </span><span class="n">a</span><span class="p">,</span><span class="w"> </span><span class="n">b</span><span class="p">))</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span><span class="si">}</span><span class="s2">."</span><span class="p">)</span> <span class="c1"># We use the whole training set to predict the model accuracy</span> <span class="n">fit_and_predict</span><span class="p">(</span><span class="n">train_acc</span><span class="p">,</span> <span class="n">sample_sizes</span><span class="p">,</span> <span class="n">pred_sample_size</span><span class="o">=</span><span class="n">num_train_samples</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Curve fit weights: a = 0.6445642113685608 and b = 0.048097413033246994. A model accuracy of 0.9517362117767334 is predicted for 3303 samples. 
</code></pre></div> </div> <p><img alt="png" src="/img/examples/keras_recipes/sample_size_estimate/sample_size_estimate_17_1.png" /></p> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>The mae for the curve fit is 0.016098767518997192. </code></pre></div> </div> <p>From the extrapolated curve we can see that 3303 images will yield an estimated accuracy of about 95%.</p> <p>Now, let's use all the data (3303 images) and train the model to see if our prediction was accurate!</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Now train the model with full dataset to get the actual accuracy</span> <span class="n">accuracy</span> <span class="o">=</span> <span class="n">train_model</span><span class="p">(</span><span class="n">img_train</span><span class="p">,</span> <span class="n">label_train</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">"A model accuracy of </span><span class="si">{</span><span class="n">accuracy</span><span class="si">}</span><span class="s2"> is reached on </span><span class="si">{</span><span class="n">num_train_samples</span><span class="si">}</span><span class="s2"> images!"</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>/var/folders/8n/8w8cqnvj01xd4ghznl11nyn000_93_/T/ipykernel_30919/1838736464.py:16: UserWarning: `input_shape` is undefined or non-square, or `rows` is not in [96, 128, 160, 192, 224]. Weights for input shape (224, 224) will be loaded as the default. 
model = keras.applications.MobileNetV2( Trainable weights: 2 Non_trainable weights: 260 Epoch 1/10 47/47 ━━━━━━━━━━━━━━━━━━━━ 18s 338ms/step - acc: 0.4305 - auc: 0.7221 - loss: 1.4585 - val_acc: 0.8218 - val_auc: 0.9700 - val_loss: 0.5043 Epoch 2/10 47/47 ━━━━━━━━━━━━━━━━━━━━ 15s 326ms/step - acc: 0.7666 - auc: 0.9504 - loss: 0.6287 - val_acc: 0.8792 - val_auc: 0.9838 - val_loss: 0.3733 Epoch 3/10 47/47 ━━━━━━━━━━━━━━━━━━━━ 16s 332ms/step - acc: 0.8252 - auc: 0.9673 - loss: 0.5039 - val_acc: 0.8852 - val_auc: 0.9880 - val_loss: 0.3182 Epoch 4/10 47/47 ━━━━━━━━━━━━━━━━━━━━ 16s 348ms/step - acc: 0.8458 - auc: 0.9768 - loss: 0.4264 - val_acc: 0.8822 - val_auc: 0.9893 - val_loss: 0.2956 Epoch 5/10 47/47 ━━━━━━━━━━━━━━━━━━━━ 16s 350ms/step - acc: 0.8661 - auc: 0.9812 - loss: 0.3821 - val_acc: 0.8912 - val_auc: 0.9903 - val_loss: 0.2755 Epoch 6/10 47/47 ━━━━━━━━━━━━━━━━━━━━ 16s 336ms/step - acc: 0.8656 - auc: 0.9836 - loss: 0.3555 - val_acc: 0.9003 - val_auc: 0.9906 - val_loss: 0.2701 Epoch 7/10 47/47 ━━━━━━━━━━━━━━━━━━━━ 16s 331ms/step - acc: 0.8800 - auc: 0.9846 - loss: 0.3430 - val_acc: 0.8943 - val_auc: 0.9914 - val_loss: 0.2548 Epoch 8/10 47/47 ━━━━━━━━━━━━━━━━━━━━ 16s 333ms/step - acc: 0.8917 - auc: 0.9871 - loss: 0.3143 - val_acc: 0.8973 - val_auc: 0.9917 - val_loss: 0.2494 Epoch 9/10 47/47 ━━━━━━━━━━━━━━━━━━━━ 15s 320ms/step - acc: 0.9003 - auc: 0.9891 - loss: 0.2906 - val_acc: 0.9063 - val_auc: 0.9908 - val_loss: 0.2463 Epoch 10/10 47/47 ━━━━━━━━━━━━━━━━━━━━ 15s 324ms/step - acc: 0.8997 - auc: 0.9895 - loss: 0.2839 - val_acc: 0.9124 - val_auc: 0.9912 - val_loss: 0.2394 Trainable weights: 24 Non-trainable weights: 238 Epoch 1/29 47/47 ━━━━━━━━━━━━━━━━━━━━ 27s 537ms/step - acc: 0.8457 - auc: 0.9747 - loss: 0.4365 - val_acc: 0.9094 - val_auc: 0.9916 - val_loss: 0.2692 Epoch 2/29 47/47 ━━━━━━━━━━━━━━━━━━━━ 24s 502ms/step - acc: 0.9223 - auc: 0.9932 - loss: 0.2198 - val_acc: 0.9033 - val_auc: 0.9891 - val_loss: 0.2826 Epoch 3/29 47/47 ━━━━━━━━━━━━━━━━━━━━ 25s 
534ms/step - acc: 0.9499 - auc: 0.9972 - loss: 0.1399 - val_acc: 0.9003 - val_auc: 0.9910 - val_loss: 0.2804 Epoch 4/29 47/47 ━━━━━━━━━━━━━━━━━━━━ 26s 554ms/step - acc: 0.9590 - auc: 0.9983 - loss: 0.1130 - val_acc: 0.9396 - val_auc: 0.9968 - val_loss: 0.1510 Epoch 5/29 47/47 ━━━━━━━━━━━━━━━━━━━━ 25s 533ms/step - acc: 0.9805 - auc: 0.9996 - loss: 0.0538 - val_acc: 0.9486 - val_auc: 0.9914 - val_loss: 0.1795 Epoch 6/29 47/47 ━━━━━━━━━━━━━━━━━━━━ 24s 516ms/step - acc: 0.9949 - auc: 1.0000 - loss: 0.0226 - val_acc: 0.9124 - val_auc: 0.9833 - val_loss: 0.3186 Epoch 7/29 47/47 ━━━━━━━━━━━━━━━━━━━━ 25s 534ms/step - acc: 0.9900 - auc: 0.9999 - loss: 0.0297 - val_acc: 0.9275 - val_auc: 0.9881 - val_loss: 0.3017 Epoch 8/29 47/47 ━━━━━━━━━━━━━━━━━━━━ 25s 536ms/step - acc: 0.9910 - auc: 0.9999 - loss: 0.0228 - val_acc: 0.9426 - val_auc: 0.9927 - val_loss: 0.1938 Epoch 9/29 47/47 ━━━━━━━━━━━━━━━━━━━━ 0s 489ms/step - acc: 0.9995 - auc: 1.0000 - loss: 0.0069Restoring model weights from the end of the best epoch: 4. 47/47 ━━━━━━━━━━━━━━━━━━━━ 25s 527ms/step - acc: 0.9995 - auc: 1.0000 - loss: 0.0068 - val_acc: 0.9426 - val_auc: 0.9919 - val_loss: 0.2957 Epoch 9: early stopping 12/12 ━━━━━━━━━━━━━━━━━━━━ 2s 170ms/step - acc: 0.9641 - auc: 0.9972 - loss: 0.1264 A model accuracy of 0.9964 is reached on 3303 images! </code></pre></div> </div> <hr /> <h2 id="conclusion">Conclusion</h2> <p>We see that a model accuracy of about 94-96% is reached using 3303 images. This is quite close to our estimate!</p> <p>Even though we used only 50% of the dataset (1651 images), we were able to model the training behaviour of our model and predict its accuracy for a given number of images. This same methodology can be used to predict the number of images needed to reach a desired accuracy. This is very useful when only a small dataset is available and training has been shown to converge, but more images are needed to reach the target accuracy. 
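</p> <p>As a rough sketch of that inversion: the fitted weights printed earlier (<code>a</code> ≈ 0.6446, <code>b</code> ≈ 0.0481) reproduce the 0.9517 prediction for 3303 samples if the curve has the power-law form <code>a * x**b</code>, so that form is assumed here. Solving it for the sample count gives <code>n = (target / a) ** (1 / b)</code>. The helper names <code>predicted_accuracy</code> and <code>samples_needed</code> are hypothetical and not part of the example's code.</p>

```python
import math

# Fitted weights copied from the "Curve fit weights" output shown earlier.
A = 0.6445642113685608
B = 0.048097413033246994


def predicted_accuracy(n_samples, a=A, b=B):
    # Assumed curve form (not confirmed by this section): accuracy = a * n ** b.
    return a * n_samples**b


def samples_needed(target_accuracy, a=A, b=B):
    # Solve a * n ** b = target_accuracy for n, rounding up to whole images.
    return math.ceil((target_accuracy / a) ** (1.0 / b))


# Sanity check: the assumed form should land close to the printed estimate.
print(f"Predicted accuracy for 3303 samples: {predicted_accuracy(3303):.4f}")
print(f"Images needed for 95% accuracy: {samples_needed(0.95)}")
```

<p>Because this assumed curve keeps rising without a ceiling, targets far beyond the accuracies actually measured should be treated as loose estimates only.</p> <p>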
The predicted image count can then be used to plan and budget for further image collection.</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#estimating-required-sample-size-for-model-training'>Estimating required sample size for model training</a> </div> <div class='k-outline-depth-1'> <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#setup'>Setup</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#load-tensorflow-dataset-and-convert-to-numpy-arrays'>Load TensorFlow dataset and convert to NumPy arrays</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#plot-a-few-examples-from-the-test-set'>Plot a few examples from the test set</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#augmentation'>Augmentation</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#define-model-building-amp-training-functions'>Define model building &amp; training functions</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#define-iterative-training-function'>Define iterative training function</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-models-iteratively'>Train models iteratively</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#learning-curve'>Learning curve</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#conclusion'>Conclusion</a> </div> </div> </div> </div> </div> </body> </html>