CINXE.COM

Keras FAQ

<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/getting_started/faq/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Keras FAQ"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Keras FAQ"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Keras FAQ</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link active" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-sublink" href="/getting_started/intro_to_keras_for_engineers/">Introduction to Keras for engineers</a> <a class="nav-sublink" href="/getting_started/benchmarks/">Keras 3 benchmarks</a> <a class="nav-sublink" href="/getting_started/ecosystem/">The Keras ecosystem</a> <a class="nav-sublink active" href="/getting_started/faq/">Frequently Asked Questions</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/getting_started/'>Getting started</a> / Keras FAQ </div> <div class='k-content'> <h1 id="keras-faq">Keras FAQ</h1> <p>A list of frequently Asked Keras Questions.</p> <h2 id="general-questions">General questions</h2> <ul> <li><a href="#how-can-i-train-a-keras-model-on-multiple-gpus-on-a-single-machine">How can I train a Keras model on multiple GPUs (on a single machine)?</a></li> <li><a href="#how-can-i-train-a-keras-model-on-tpu">How can I train a Keras model on TPU?</a></li> <li><a href="#where-is-the-keras-configuration-file-stored">Where is the Keras configuration file stored?</a></li> <li><a href="#how-to-do-hyperparameter-tuning-with-keras">How to do hyperparameter tuning with Keras?</a></li> <li><a href="#how-can-i-obtain-reproducible-results-using-keras-during-development">How can I obtain reproducible results using Keras during development?</a></li> <li><a href="#what-are-my-options-for-saving-models">What are my options for saving models?</a></li> <li><a href="#how-can-i-install-hdf5-or-h5py-to-save-my-models">How can I install HDF5 or h5py to save my models?</a></li> <li><a href="#how-should-i-cite-keras">How should I cite Keras?</a></li> </ul> <h2 id="trainingrelated-questions">Training-related questions</h2> <ul> <li><a href="#what-do-sample-batch-and-epoch-mean">What do "sample", "batch", and "epoch" mean?</a></li> <li><a href="#why-is-my-training-loss-much-higher-than-my-testing-loss">Why is my training loss much higher than my testing loss?</a></li> <li><a href="#how-can-i-ensure-my-training-run-can-recover-from-program-interruptions">How can I ensure my training run can recover from program interruptions?</a></li> <li><a href="#how-can-i-interrupt-training-when-the-validation-loss-isnt-decreasing-anymore">How can I interrupt training when the validation loss isn't decreasing anymore?</a></li> <li><a href="#how-can-i-freeze-layers-and-do-finetuning">How can I freeze layers and do fine-tuning?</a></li> <li><a href="#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute">What's the difference between the <code>training</code> argument in <code>call()</code> and the <code>trainable</code> attribute?</a></li> <li><a href="#in-fit-how-is-the-validation-split-computed">In <code>fit()</code>, how is the validation split computed?</a></li> <li><a href="#in-fit-is-the-data-shuffled-during-training">In <code>fit()</code>, is the data shuffled during training?</a></li> <li><a href="#whats-the-recommended-way-to-monitor-my-metrics-when-training-with-fit">What's the recommended way to monitor my metrics when training with <code>fit()</code>?</a></li> <li><a href="#what-if-i-need-to-customize-what-fit-does">What if I need to customize what <code>fit()</code> does?</a></li> <li><a href="#whats-the-difference-between-model-methods-predict-and-call">What's the difference between <code>Model</code> methods <code>predict()</code> and <code>__call__()</code>?</a></li> </ul> <h2 id="modelingrelated-questions">Modeling-related questions</h2> <ul> <li><a href="#how-can-i-obtain-the-output-of-an-intermediate-layer-feature-extraction">How can I obtain the output of an intermediate layer (feature extraction)?</a></li> <li><a href="#how-can-i-use-pre-trained-models-in-keras">How can I use pre-trained models in Keras?</a></li> <li><a href="#how-can-i-use-stateful-rnns">How can I use stateful RNNs?</a></li> </ul> <hr /> <h2 id="general-questions">General questions</h2> <h3 id="how-can-i-train-a-keras-model-on-multiple-gpus-on-a-single-machine">How can I train a Keras model on multiple GPUs (on a single machine)?</h3> <p>There are two ways to run a single model on multiple GPUs: <strong>data parallelism</strong> and <strong>device parallelism</strong>. Keras covers both.</p> <p>For data parallelism, Keras supports the built-in data parallel distribution APIs of JAX, TensorFlow, and PyTorch. See the following guides:</p> <ul> <li><a href="/guides/distributed_training_with_jax/">Multi-GPU distributed training with JAX</a></li> <li><a href="/guides/distributed_training_with_tensorflow/">Multi-GPU distributed training with TensorFlow</a></li> <li><a href="/guides/distributed_training_with_torch/">Multi-GPU distributed training with PyTorch</a></li> </ul> <p>For model parallelism, Keras has its own distribution API, which is currently only support by the JAX backend. See <a href="/api/distribution/">the documentation for the <code>LayoutMap</code> API</a>.</p> <hr /> <h3 id="how-can-i-train-a-keras-model-on-tpu">How can I train a Keras model on TPU?</h3> <p>TPUs are a fast &amp; efficient hardware accelerator for deep learning that is publicly available on Google Cloud. You can use TPUs via Colab, Kaggle notebooks, and GCP Deep Learning VMs (provided the <code>TPU_NAME</code> environment variable is set on the VM).</p> <p>All Keras backends (JAX, TensorFlow, PyTorch) are supported on TPU, but we recommend JAX or TensorFlow in this case.</p> <p><strong>Using JAX:</strong></p> <p>When connected to a TPU runtime, just insert this code snippet before model construction:</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">jax</span> <span class="n">distribution</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">distribution</span><span class="o">.</span><span class="n">DataParallel</span><span class="p">(</span><span class="n">devices</span><span class="o">=</span><span class="n">jax</span><span class="o">.</span><span class="n">devices</span><span class="p">())</span> <span class="n">keras</span><span class="o">.</span><span class="n">distribution</span><span class="o">.</span><span class="n">set_distribution</span><span class="p">(</span><span class="n">distribution</span><span class="p">)</span> </code></pre></div> <p><strong>Using TensorFlow:</strong></p> <p>When connected to a TPU runtime, use <code>TPUClusterResolver</code> to detect the TPU. Then, create <code>TPUStrategy</code> and construct your model in the strategy scope:</p> <div class="codehilite"><pre><span></span><code><span class="k">try</span><span class="p">:</span> <span class="n">tpu</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">distribute</span><span class="o">.</span><span class="n">cluster_resolver</span><span class="o">.</span><span class="n">TPUClusterResolver</span><span class="o">.</span><span class="n">connect</span><span class="p">()</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Device:&quot;</span><span class="p">,</span> <span class="n">tpu</span><span class="o">.</span><span class="n">master</span><span class="p">())</span> <span class="n">strategy</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">distribute</span><span class="o">.</span><span class="n">TPUStrategy</span><span class="p">(</span><span class="n">tpu</span><span class="p">)</span> <span class="k">except</span><span class="p">:</span> <span class="n">strategy</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">distribute</span><span class="o">.</span><span class="n">get_strategy</span><span class="p">()</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Number of replicas:&quot;</span><span class="p">,</span> <span class="n">strategy</span><span class="o">.</span><span class="n">num_replicas_in_sync</span><span class="p">)</span> <span class="k">with</span> <span class="n">strategy</span><span class="o">.</span><span class="n">scope</span><span class="p">():</span> <span class="c1"># Create your model here.</span> <span class="o">...</span> </code></pre></div> <p>Importantly, you should:</p> <ul> <li>Make sure you are able to read your data fast enough to keep the TPU utilized.</li> <li>Consider running multiple steps of gradient descent per graph execution in order to keep the TPU utilized. You can do this via the <code>experimental_steps_per_execution</code> argument <code>compile()</code>. It will yield a significant speed up for small models.</li> </ul> <hr /> <h3 id="where-is-the-keras-configuration-file-stored">Where is the Keras configuration file stored?</h3> <p>The default directory where all Keras data is stored is:</p> <p><code>$HOME/.keras/</code></p> <p>For instance, for me, on a MacBook Pro, it's <code>/Users/fchollet/.keras/</code>.</p> <p>Note that Windows users should replace <code>$HOME</code> with <code>%USERPROFILE%</code>.</p> <p>In case Keras cannot create the above directory (e.g. due to permission issues), <code>/tmp/.keras/</code> is used as a backup.</p> <p>The Keras configuration file is a JSON file stored at <code>$HOME/.keras/keras.json</code>. The default configuration file looks like this:</p> <div class="codehilite"><pre><span></span><code>{ &quot;image_data_format&quot;: &quot;channels_last&quot;, &quot;epsilon&quot;: 1e-07, &quot;floatx&quot;: &quot;float32&quot;, &quot;backend&quot;: &quot;tensorflow&quot; } </code></pre></div> <p>It contains the following fields:</p> <ul> <li>The image data format to be used as default by image processing layers and utilities (either <code>channels_last</code> or <code>channels_first</code>).</li> <li>The <code>epsilon</code> numerical fuzz factor to be used to prevent division by zero in some operations.</li> <li>The default float data type.</li> <li>The default backend. It can be one of <code>"jax"</code>, <code>"tensorflow"</code>, <code>"torch"</code>, or <code>"numpy"</code>.</li> </ul> <p>Likewise, cached dataset files, such as those downloaded with <a href="/utils/#get_file"><code>get_file()</code></a>, are stored by default in <code>$HOME/.keras/datasets/</code>, and cached model weights files from Keras Applications are stored by default in <code>$HOME/.keras/models/</code>.</p> <hr /> <h3 id="how-to-do-hyperparameter-tuning-with-keras">How to do hyperparameter tuning with Keras?</h3> <p>We recommend using <a href="https://keras.io/keras_tuner/">KerasTuner</a>.</p> <hr /> <h3 id="how-can-i-obtain-reproducible-results-using-keras-during-development">How can I obtain reproducible results using Keras during development?</h3> <p>There are four sources of randomness to consider:</p> <ol> <li>Keras itself (e.g. <code>keras.random</code> ops or random layers from <code>keras.layers</code>).</li> <li>The current Keras backend (e.g. JAX, TensorFlow, or PyTorch).</li> <li>The Python runtime.</li> <li>The CUDA runtime. When running on a GPU, some operations have non-deterministic outputs. This is due to the fact that GPUs run many operations in parallel, so the order of execution is not always guaranteed. Due to the limited precision of floats, even adding several numbers together may give slightly different results depending on the order in which you add them.</li> </ol> <p>To make both Keras and the current backend framework deterministic, use this:</p> <div class="codehilite"><pre><span></span><code><span class="n">keras</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">set_random_seed</span><span class="p">(</span><span class="mi">1337</span><span class="p">)</span> </code></pre></div> <p>To make Python deterministic, you need to set the <code>PYTHONHASHSEED</code> environment variable to <code>0</code> before the program starts (not within the program itself). This is necessary in Python 3.2.3 onwards to have reproducible behavior for certain hash-based operations (e.g., the item order in a set or a dict, see <a href="https://docs.python.org/3.7/using/cmdline.html#envvar-PYTHONHASHSEED">Python's documentation</a>).</p> <p>To make the CUDA runtime deterministic: if using the TensorFlow backend, call <a href="https://www.tensorflow.org/api_docs/python/tf/config/experimental/enable_op_determinism"><code>tf.config.experimental.enable_op_determinism</code></a>. Note that this will have a performance cost. What to do for other backends may vary &ndash; check the documentation of your backend framework directly.</p> <hr /> <h3 id="what-are-my-options-for-saving-models">What are my options for saving models?</h3> <p><em>Note: it is not recommended to use pickle or cPickle to save a Keras model.</em></p> <p><strong>1) Whole-model saving (configuration + weights)</strong></p> <p>Whole-model saving means creating a file that will contain:</p> <ul> <li>the architecture of the model, allowing you to re-create the model</li> <li>the weights of the model</li> <li>the training configuration (loss, optimizer)</li> <li>the state of the optimizer, allowing you to resume training exactly where you left off.</li> </ul> <p>The default and recommended way to save a whole model is to just do: <code>model.save(your_file_path.keras)</code>.</p> <p>After saving a model in either format, you can reinstantiate it via <code>model = keras.models.load_model(your_file_path.keras)</code>.</p> <p><strong>Example:</strong></p> <div class="codehilite"><pre><span></span><code><span class="kn">from</span> <span class="nn">keras.saving</span> <span class="kn">import</span> <span class="n">load_model</span> <span class="n">model</span><span class="o">.</span><span class="n">save</span><span class="p">(</span><span class="s1">&#39;my_model.keras&#39;</span><span class="p">)</span> <span class="k">del</span> <span class="n">model</span> <span class="c1"># deletes the existing model</span> <span class="c1"># returns a compiled model</span> <span class="c1"># identical to the previous one</span> <span class="n">model</span> <span class="o">=</span> <span class="n">load_model</span><span class="p">(</span><span class="s1">&#39;my_model.keras&#39;</span><span class="p">)</span> </code></pre></div> <p><strong>2) Weights-only saving</strong></p> <p>If you need to save the <strong>weights of a model</strong>, you can do so in HDF5 with the code below, using the file extension <code>.weights.h5</code>:</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span><span class="o">.</span><span class="n">save_weights</span><span class="p">(</span><span class="s1">&#39;my_model.weights.h5&#39;</span><span class="p">)</span> </code></pre></div> <p>Assuming you have code for instantiating your model, you can then load the weights you saved into a model with the <em>same</em> architecture:</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span><span class="o">.</span><span class="n">load_weights</span><span class="p">(</span><span class="s1">&#39;my_model.weights.h5&#39;</span><span class="p">)</span> </code></pre></div> <p>If you need to load the weights into a <em>different</em> architecture (with some layers in common), for instance for fine-tuning or transfer-learning, you can load them by <em>layer name</em>:</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span><span class="o">.</span><span class="n">load_weights</span><span class="p">(</span><span class="s1">&#39;my_model.weights.h5&#39;</span><span class="p">,</span> <span class="n">by_name</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> </code></pre></div> <p>Example:</p> <div class="codehilite"><pre><span></span><code><span class="sd">&quot;&quot;&quot;</span> <span class="sd">Assuming the original model looks like this:</span> <span class="sd">model = Sequential()</span> <span class="sd">model.add(Dense(2, input_dim=3, name=&#39;dense_1&#39;))</span> <span class="sd">model.add(Dense(3, name=&#39;dense_2&#39;))</span> <span class="sd">...</span> <span class="sd">model.save_weights(fname)</span> <span class="sd">&quot;&quot;&quot;</span> <span class="c1"># new model</span> <span class="n">model</span> <span class="o">=</span> <span class="n">Sequential</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">Dense</span><span class="p">(</span><span class="mi">2</span><span class="p">,</span> <span class="n">input_dim</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">&#39;dense_1&#39;</span><span class="p">))</span> <span class="c1"># will be loaded</span> <span class="n">model</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">&#39;new_dense&#39;</span><span class="p">))</span> <span class="c1"># will not be loaded</span> <span class="c1"># load weights from the first model; will only affect the first layer, dense_1.</span> <span class="n">model</span><span class="o">.</span><span class="n">load_weights</span><span class="p">(</span><span class="n">fname</span><span class="p">,</span> <span class="n">by_name</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> </code></pre></div> <p>See also <a href="#how-can-i-install-hdf5-or-h5py-to-save-my-models">How can I install HDF5 or h5py to save my models?</a> for instructions on how to install <code>h5py</code>.</p> <p><strong>3) Configuration-only saving (serialization)</strong></p> <p>If you only need to save the <strong>architecture of a model</strong>, and not its weights or its training configuration, you can do:</p> <div class="codehilite"><pre><span></span><code><span class="c1"># save as JSON</span> <span class="n">json_string</span> <span class="o">=</span> <span class="n">model</span><span class="o">.</span><span class="n">to_json</span><span class="p">()</span> </code></pre></div> <p>The generated JSON file is human-readable and can be manually edited if needed.</p> <p>You can then build a fresh model from this data:</p> <div class="codehilite"><pre><span></span><code><span class="c1"># model reconstruction from JSON:</span> <span class="kn">from</span> <span class="nn">keras.models</span> <span class="kn">import</span> <span class="n">model_from_json</span> <span class="n">model</span> <span class="o">=</span> <span class="n">model_from_json</span><span class="p">(</span><span class="n">json_string</span><span class="p">)</span> </code></pre></div> <p><strong>4) Handling custom layers (or other custom objects) in saved models</strong></p> <p>If the model you want to load includes custom layers or other custom classes or functions, you can pass them to the loading mechanism via the <code>custom_objects</code> argument:</p> <div class="codehilite"><pre><span></span><code><span class="kn">from</span> <span class="nn">keras.models</span> <span class="kn">import</span> <span class="n">load_model</span> <span class="c1"># Assuming your model includes instance of an &quot;AttentionLayer&quot; class</span> <span class="n">model</span> <span class="o">=</span> <span class="n">load_model</span><span class="p">(</span><span class="s1">&#39;my_model.h5&#39;</span><span class="p">,</span> <span class="n">custom_objects</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;AttentionLayer&#39;</span><span class="p">:</span> <span class="n">AttentionLayer</span><span class="p">})</span> </code></pre></div> <p>Alternatively, you can use a <a href="https://keras.io/utils/#customobjectscope">custom object scope</a>:</p> <div class="codehilite"><pre><span></span><code><span class="kn">from</span> <span class="nn">keras.utils</span> <span class="kn">import</span> <span class="n">CustomObjectScope</span> <span class="k">with</span> <span class="n">CustomObjectScope</span><span class="p">({</span><span class="s1">&#39;AttentionLayer&#39;</span><span class="p">:</span> <span class="n">AttentionLayer</span><span class="p">}):</span> <span class="n">model</span> <span class="o">=</span> <span class="n">load_model</span><span class="p">(</span><span class="s1">&#39;my_model.h5&#39;</span><span class="p">)</span> </code></pre></div> <p>Custom objects handling works the same way for <code>load_model</code> &amp; <code>model_from_json</code>:</p> <div class="codehilite"><pre><span></span><code><span class="kn">from</span> <span class="nn">keras.models</span> <span class="kn">import</span> <span class="n">model_from_json</span> <span class="n">model</span> <span class="o">=</span> <span class="n">model_from_json</span><span class="p">(</span><span class="n">json_string</span><span class="p">,</span> <span class="n">custom_objects</span><span class="o">=</span><span class="p">{</span><span class="s1">&#39;AttentionLayer&#39;</span><span class="p">:</span> <span class="n">AttentionLayer</span><span class="p">})</span> </code></pre></div> <hr /> <h3 id="how-can-i-install-hdf5-or-h5py-to-save-my-models">How can I install HDF5 or h5py to save my models?</h3> <p>In order to save your Keras models as HDF5 files, Keras uses the h5py Python package. It is a dependency of Keras and should be installed by default. On Debian-based distributions, you will have to additionally install <code>libhdf5</code>:</p> <div class="k-default-code-block"> <div class="codehilite"><pre><span></span><code>sudo apt-get install libhdf5-serial-dev </code></pre></div> </div> <p>If you are unsure if h5py is installed you can open a Python shell and load the module via</p> <div class="codehilite"><pre><span></span><code>import h5py </code></pre></div> <p>If it imports without error it is installed, otherwise you can find <a href="http://docs.h5py.org/en/latest/build.html">detailed installation instructions here</a>.</p> <hr /> <h3 id="how-should-i-cite-keras">How should I cite Keras?</h3> <p>Please cite Keras in your publications if it helps your research. Here is an example BibTeX entry:</p> <p><code style="color: gray;"> @misc{chollet2015keras,<br> &nbsp;&nbsp;title={Keras},<br> &nbsp;&nbsp;author={Chollet, Fran\c{c}ois and others},<br> &nbsp;&nbsp;year={2015},<br> &nbsp;&nbsp;howpublished={\url{https://keras.io}},<br> } </code></p> <hr /> <h2 id="trainingrelated-questions">Training-related questions</h2> <h3 id="what-do-sample-batch-and-epoch-mean">What do "sample", "batch", and "epoch" mean?</h3> <p>Below are some common definitions that are necessary to know and understand to correctly utilize Keras <code>fit()</code>:</p> <ul> <li><strong>Sample</strong>: one element of a dataset. For instance, one image is a <strong>sample</strong> in a convolutional network. One audio snippet is a <strong>sample</strong> for a speech recognition model.</li> </ul> <ul> <li><strong>Batch</strong>: a set of <em>N</em> samples. The samples in a <strong>batch</strong> are processed independently, in parallel. If training, a batch results in only one update to the model. A <strong>batch</strong> generally approximates the distribution of the input data better than a single input. The larger the batch, the better the approximation; however, it is also true that the batch will take longer to process and will still result in only one update. For inference (evaluate/predict), it is recommended to pick a batch size that is as large as you can afford without going out of memory (since larger batches will usually result in faster evaluation/prediction).</li> </ul> <ul> <li><strong>Epoch</strong>: an arbitrary cutoff, generally defined as "one pass over the entire dataset", used to separate training into distinct phases, which is useful for logging and periodic evaluation. When using <code>validation_data</code> or <code>validation_split</code> with the <code>fit</code> method of Keras models, evaluation will be run at the end of every <strong>epoch</strong>. Within Keras, there is the ability to add <a href="/api/callbacks/">callbacks</a> specifically designed to be run at the end of an <strong>epoch</strong>. Examples of these are learning rate changes and model checkpointing (saving).</li> </ul> <hr /> <h3 id="why-is-my-training-loss-much-higher-than-my-testing-loss">Why is my training loss much higher than my testing loss?</h3> <p>A Keras model has two modes: training and testing. Regularization mechanisms, such as Dropout and L1/L2 weight regularization, are turned off at testing time. They are reflected in the training time loss but not in the test time loss.</p> <p>Besides, the training loss that Keras displays is the average of the losses for each batch of training data, <strong>over the current epoch</strong>. Because your model is changing over time, the loss over the first batches of an epoch is generally higher than over the last batches. This can bring the epoch-wise average down. On the other hand, the testing loss for an epoch is computed using the model as it is at the end of the epoch, resulting in a lower loss.</p> <hr /> <h3 id="how-can-i-ensure-my-training-run-can-recover-from-program-interruptions">How can I ensure my training run can recover from program interruptions?</h3> <p>To ensure the ability to recover from an interrupted training run at any time (fault tolerance), you should use a <a href="/api/callbacks/backup_and_restore#backupandrestore-class"><code>keras.callbacks.BackupAndRestore</code></a> callback that regularly saves your training progress, including the epoch number and weights, to disk, and loads it the next time you call <code>Model.fit()</code>.</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">keras</span> <span class="k">class</span> <span class="nc">InterruptingCallback</span><span class="p">(</span><span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">Callback</span><span class="p">):</span> <span class="w"> </span><span class="sd">&quot;&quot;&quot;A callback to intentionally introduce interruption to training.&quot;&quot;&quot;</span> <span class="k">def</span> <span class="nf">on_epoch_end</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">epoch</span><span class="p">,</span> <span class="n">log</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span> <span class="k">if</span> <span class="n">epoch</span> <span class="o">==</span> <span class="mi">15</span><span class="p">:</span> <span class="k">raise</span> <span class="ne">RuntimeError</span><span class="p">(</span><span class="s1">&#39;Interruption&#39;</span><span class="p">)</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">)])</span> <span class="n">optimizer</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">optimizers</span><span class="o">.</span><span class="n">SGD</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s2">&quot;mse&quot;</span><span class="p">)</span> <span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">24</span><span class="p">,</span> <span class="mi">10</span><span class="p">))</span> <span class="n">y</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">24</span><span class="p">,))</span> <span class="n">backup_callback</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">callbacks</span><span class="o">.</span><span class="n">experimental</span><span class="o">.</span><span class="n">BackupAndRestore</span><span class="p">(</span> <span class="n">backup_dir</span><span class="o">=</span><span class="s1">&#39;/tmp/backup&#39;</span><span class="p">)</span> <span class="k">try</span><span class="p">:</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">steps_per_epoch</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">backup_callback</span><span class="p">,</span> <span class="n">InterruptingCallback</span><span class="p">()])</span> <span class="k">except</span> <span class="ne">RuntimeError</span><span class="p">:</span> <span class="nb">print</span><span class="p">(</span><span class="s1">&#39;***Handling interruption***&#39;</span><span class="p">)</span> <span class="c1"># This continues at the epoch where it left off.</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">20</span><span class="p">,</span> <span class="n">steps_per_epoch</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">backup_callback</span><span class="p">])</span> </code></pre></div> <p>Find out more in the <a href="/api/callbacks/">callbacks documentation</a>.</p> <hr /> <h3 id="how-can-i-interrupt-training-when-the-validation-loss-isnt-decreasing-anymore">How can I interrupt training when the validation loss isn't decreasing anymore?</h3> <p>You can use an <code>EarlyStopping</code> callback:</p> <div class="codehilite"><pre><span></span><code><span class="kn">from</span> <span class="nn">keras.callbacks</span> <span class="kn">import</span> <span class="n">EarlyStopping</span> <span class="n">early_stopping</span> <span class="o">=</span> <span class="n">EarlyStopping</span><span class="p">(</span><span class="n">monitor</span><span class="o">=</span><span class="s1">&#39;val_loss&#39;</span><span class="p">,</span> <span class="n">patience</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">y</span><span class="p">,</span> <span class="n">validation_split</span><span class="o">=</span><span class="mf">0.2</span><span class="p">,</span> <span class="n">callbacks</span><span class="o">=</span><span class="p">[</span><span class="n">early_stopping</span><span class="p">])</span> </code></pre></div> <p>Find out more in the <a href="/api/callbacks/">callbacks documentation</a>.</p> <hr /> <h3 id="how-can-i-freeze-layers-and-do-finetuning">How can I freeze layers and do fine-tuning?</h3> <p><strong>Setting the <code>trainable</code> attribute</strong></p> <p>All layers &amp; models have a <code>layer.trainable</code> boolean attribute:</p> <div class="codehilite"><pre><span></span><code>&gt;&gt;&gt;<span class="w"> </span><span class="nv">layer</span><span class="w"> </span><span class="o">=</span><span class="w"> </span>Dense<span class="o">(</span><span class="m">3</span><span class="o">)</span> &gt;&gt;&gt;<span class="w"> </span>layer.trainable True </code></pre></div> <p>On all layers &amp; models, the <code>trainable</code> attribute can be set (to True or False). When set to <code>False</code>, the <code>layer.trainable_weights</code> attribute is empty:</p> <div class="codehilite"><pre><span></span><code><span class="o">&gt;&gt;&gt;</span> <span class="n">layer</span> <span class="o">=</span> <span class="n">Dense</span><span class="p">(</span><span class="mi">3</span><span class="p">)</span> <span class="o">&gt;&gt;&gt;</span> <span class="n">layer</span><span class="o">.</span><span class="n">build</span><span class="p">(</span><span class="n">input_shape</span><span class="o">=</span><span class="p">(</span><span class="kc">None</span><span class="p">,</span> <span class="mi">3</span><span class="p">))</span> <span class="c1"># Create the weights of the layer</span> <span class="o">&gt;&gt;&gt;</span> <span class="n">layer</span><span class="o">.</span><span class="n">trainable</span> <span class="kc">True</span> <span class="o">&gt;&gt;&gt;</span> <span class="n">layer</span><span class="o">.</span><span class="n">trainable_weights</span> <span class="p">[</span><span class="o">&lt;</span><span class="n">KerasVariable</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">float32</span><span class="p">,</span> <span class="n">path</span><span class="o">=</span><span class="n">dense</span><span class="o">/</span><span class="n">kernel</span><span class="o">&gt;</span><span class="p">,</span> <span class="o">&lt;</span><span class="n">KerasVariable</span> <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="mi">3</span><span class="p">,),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">float32</span><span class="p">,</span> <span class="n">path</span><span class="o">=</span><span class="n">dense</span><span class="o">/</span><span class="n">bias</span><span class="o">&gt;</span><span class="p">]</span> <span class="o">&gt;&gt;&gt;</span> <span class="n">layer</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">False</span> <span class="o">&gt;&gt;&gt;</span> <span class="n">layer</span><span class="o">.</span><span class="n">trainable_weights</span> <span class="p">[]</span> </code></pre></div> <p>Setting the <code>trainable</code> attribute on a layer recursively sets it on all children layers (contents of <code>self.layers</code>).</p> <p><strong>1) When training with <code>fit()</code>:</strong></p> <p>To do fine-tuning with <code>fit()</code>, you would:</p> <ul> <li>Instantiate a base model and load pre-trained weights</li> <li>Freeze that base model</li> <li>Add trainable layers on top</li> <li>Call <code>compile()</code> and <code>fit()</code></li> </ul> <p>Like this:</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="n">Sequential</span><span class="p">([</span> <span class="n">ResNet50Base</span><span class="p">(</span><span class="n">input_shape</span><span class="o">=</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">weights</span><span class="o">=</span><span class="s1">&#39;pretrained&#39;</span><span class="p">),</span> <span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span> <span class="p">])</span> <span class="n">model</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">False</span> <span class="c1"># Freeze ResNet50Base.</span> <span class="k">assert</span> <span class="n">model</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">trainable_weights</span> <span class="o">==</span> <span class="p">[]</span> <span class="c1"># ResNet50Base has no trainable weights.</span> <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span> <span class="c1"># Just the bias &amp; kernel of the Dense layer.</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="o">...</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="o">...</span><span class="p">)</span> <span class="c1"># Train Dense while excluding ResNet50Base.</span> </code></pre></div> <p>You can follow a similar workflow with the Functional API or the model subclassing API. Make sure to call <code>compile()</code> <em>after</em> changing the value of <code>trainable</code> in order for your changes to be taken into account. Calling <code>compile()</code> will freeze the state of the training step of the model.</p> <p><strong>2) When using a custom training loop:</strong></p> <p>When writing a training loop, make sure to only update weights that are part of <code>model.trainable_weights</code> (and not all <code>model.weights</code>). Here's a simple TensorFlow example:</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="n">Sequential</span><span class="p">([</span> <span class="n">ResNet50Base</span><span class="p">(</span><span class="n">input_shape</span><span class="o">=</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">weights</span><span class="o">=</span><span class="s1">&#39;pretrained&#39;</span><span class="p">),</span> <span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span> <span class="p">])</span> <span class="n">model</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">False</span> <span class="c1"># Freeze ResNet50Base.</span> <span class="c1"># Iterate over the batches of a dataset.</span> <span class="k">for</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">targets</span> <span class="ow">in</span> <span class="n">dataset</span><span class="p">:</span> <span class="c1"># Open a GradientTape.</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">GradientTape</span><span class="p">()</span> <span class="k">as</span> <span class="n">tape</span><span class="p">:</span> <span class="c1"># Forward pass.</span> <span class="n">predictions</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">inputs</span><span class="p">)</span> <span class="c1"># Compute the loss value for this batch.</span> <span class="n">loss_value</span> <span class="o">=</span> <span class="n">loss_fn</span><span class="p">(</span><span class="n">targets</span><span class="p">,</span> <span class="n">predictions</span><span class="p">)</span> <span class="c1"># Get gradients of loss wrt the *trainable* weights.</span> <span class="n">gradients</span> <span class="o">=</span> <span class="n">tape</span><span class="o">.</span><span class="n">gradient</span><span class="p">(</span><span class="n">loss_value</span><span class="p">,</span> <span class="n">model</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">)</span> <span class="c1"># Update the weights of the model.</span> <span class="n">optimizer</span><span class="o">.</span><span class="n">apply_gradients</span><span class="p">(</span><span class="nb">zip</span><span class="p">(</span><span class="n">gradients</span><span class="p">,</span> <span class="n">model</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">))</span> </code></pre></div> <p><strong>Interaction between <code>trainable</code> and <code>compile()</code></strong></p> <p>Calling <code>compile()</code> on a model is meant to "freeze" the behavior of that model. This implies that the <code>trainable</code> attribute values at the time the model is compiled should be preserved throughout the lifetime of that model, until <code>compile</code> is called again. Hence, if you change <code>trainable</code>, make sure to call <code>compile()</code> again on your model for your changes to be taken into account.</p> <p>For instance, if two models A &amp; B share some layers, and:</p> <ul> <li>Model A gets compiled</li> <li>The <code>trainable</code> attribute value on the shared layers is changed</li> <li>Model B is compiled</li> </ul> <p>Then model A and B are using different <code>trainable</code> values for the shared layers. This mechanism is critical for most existing GAN implementations, which do:</p> <div class="codehilite"><pre><span></span><code><span class="n">discriminator</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="o">...</span><span class="p">)</span> <span class="c1"># the weights of `discriminator` should be updated when `discriminator` is trained</span> <span class="n">discriminator</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">False</span> <span class="n">gan</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="o">...</span><span class="p">)</span> <span class="c1"># `discriminator` is a submodel of `gan`, which should not be updated when `gan` is trained</span> </code></pre></div> <hr /> <h3 id="whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute">What's the difference between the <code>training</code> argument in <code>call()</code> and the <code>trainable</code> attribute?</h3> <p><code>training</code> is a boolean argument in <code>call</code> that determines whether the call should be run in inference mode or training mode. For example, in training mode, a <code>Dropout</code> layer applies random dropout and rescales the output. In inference mode, the same layer does nothing. Example:</p> <div class="codehilite"><pre><span></span><code><span class="n">y</span> <span class="o">=</span> <span class="n">Dropout</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)(</span><span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="c1"># Applies dropout at training time *and* inference time</span> </code></pre></div> <p><code>trainable</code> is a boolean layer attribute that determines the trainable weights of the layer should be updated to minimize the loss during training. If <code>layer.trainable</code> is set to <code>False</code>, then <code>layer.trainable_weights</code> will always be an empty list. Example:</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="n">Sequential</span><span class="p">([</span> <span class="n">ResNet50Base</span><span class="p">(</span><span class="n">input_shape</span><span class="o">=</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">),</span> <span class="n">weights</span><span class="o">=</span><span class="s1">&#39;pretrained&#39;</span><span class="p">),</span> <span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span> <span class="p">])</span> <span class="n">model</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">False</span> <span class="c1"># Freeze ResNet50Base.</span> <span class="k">assert</span> <span class="n">model</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">trainable_weights</span> <span class="o">==</span> <span class="p">[]</span> <span class="c1"># ResNet50Base has no trainable weights.</span> <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">model</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">)</span> <span class="o">==</span> <span class="mi">2</span> <span class="c1"># Just the bias &amp; kernel of the Dense layer.</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="o">...</span><span class="p">)</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span><span class="o">...</span><span class="p">)</span> <span class="c1"># Train Dense while excluding ResNet50Base.</span> </code></pre></div> <p>As you can see, "inference mode vs training mode" and "layer weight trainability" are two very different concepts.</p> <p>You could imagine the following: a dropout layer where the scaling factor is learned during training, via backpropagation. Let's name it <code>AutoScaleDropout</code>. This layer would have simultaneously a trainable state, and a different behavior in inference and training. Because the <code>trainable</code> attribute and the <code>training</code> call argument are independent, you can do the following:</p> <div class="codehilite"><pre><span></span><code><span class="n">layer</span> <span class="o">=</span> <span class="n">AutoScaleDropout</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)</span> <span class="c1"># Applies dropout at training time *and* inference time </span> <span class="c1"># *and* learns the scaling factor during training</span> <span class="n">y</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="k">assert</span> <span class="nb">len</span><span class="p">(</span><span class="n">layer</span><span class="o">.</span><span class="n">trainable_weights</span><span class="p">)</span> <span class="o">==</span> <span class="mi">1</span> </code></pre></div> <div class="codehilite"><pre><span></span><code><span class="c1"># Applies dropout at training time *and* inference time </span> <span class="c1"># with a *frozen* scaling factor</span> <span class="n">layer</span> <span class="o">=</span> <span class="n">AutoScaleDropout</span><span class="p">(</span><span class="mf">0.5</span><span class="p">)</span> <span class="n">layer</span><span class="o">.</span><span class="n">trainable</span> <span class="o">=</span> <span class="kc">False</span> <span class="n">y</span> <span class="o">=</span> <span class="n">layer</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> </code></pre></div> <p><strong><em>Special case of the <code>BatchNormalization</code> layer</em></strong></p> <p>For a <code>BatchNormalization</code> layer, setting <code>bn.trainable = False</code> will also make its <code>training</code> call argument default to <code>False</code>, meaning that the layer will no update its state during training.</p> <p>This behavior only applies for <code>BatchNormalization</code>. For every other layer, weight trainability and "inference vs training mode" remain independent.</p> <hr /> <h3 id="in-fit-how-is-the-validation-split-computed">In <code>fit()</code>, how is the validation split computed?</h3> <p>If you set the <code>validation_split</code> argument in <code>model.fit</code> to e.g. 0.1, then the validation data used will be the <em>last 10%</em> of the data. If you set it to 0.25, it will be the last 25% of the data, etc. Note that the data isn't shuffled before extracting the validation split, so the validation is literally just the <em>last</em> x% of samples in the input you passed.</p> <p>The same validation set is used for all epochs (within the same call to <code>fit</code>).</p> <p>Note that the <code>validation_split</code> option is only available if your data is passed as Numpy arrays (not <a href="https://www.tensorflow.org/api_docs/python/tf/data/Datasets"><code>tf.data.Datasets</code></a>, which are not indexable).</p> <hr /> <h3 id="in-fit-is-the-data-shuffled-during-training">In <code>fit()</code>, is the data shuffled during training?</h3> <p>If you pass your data as NumPy arrays and if the <code>shuffle</code> argument in <code>model.fit()</code> is set to <code>True</code> (which is the default), the training data will be globally randomly shuffled at each epoch.</p> <p>If you pass your data as a <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><code>tf.data.Dataset</code></a> object and if the <code>shuffle</code> argument in <code>model.fit()</code> is set to <code>True</code>, the dataset will be locally shuffled (buffered shuffling).</p> <p>When using <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><code>tf.data.Dataset</code></a> objects, prefer shuffling your data beforehand (e.g. by calling <code>dataset = dataset.shuffle(buffer_size)</code>) so as to be in control of the buffer size.</p> <p>Validation data is never shuffled.</p> <hr /> <h3 id="whats-the-recommended-way-to-monitor-my-metrics-when-training-with-fit">What's the recommended way to monitor my metrics when training with <code>fit()</code>?</h3> <p>Loss values and metric values are reported via the default progress bar displayed by calls to <code>fit()</code>. However, staring at changing ascii numbers in a console is not an optimal metric-monitoring experience. We recommend the use of <a href="https://www.tensorflow.org/tensorboard">TensorBoard</a>, which will display nice-looking graphs of your training and validation metrics, regularly updated during training, which you can access from your browser.</p> <p>You can use TensorBoard with <code>fit()</code> via the <a href="/api/callbacks/tensorboard/"><code>TensorBoard</code> callback</a>.</p> <hr /> <h3 id="what-if-i-need-to-customize-what-fit-does">What if I need to customize what <code>fit()</code> does?</h3> <p>You have two options:</p> <p><strong>1) Subclass the <code>Model</code> class and override the <code>train_step</code> (and <code>test_step</code>) methods</strong></p> <p>This is a better option if you want to use custom update rules but still want to leverage the functionality provided by <code>fit()</code>, such as callbacks, efficient step fusing, etc.</p> <p>Note that this pattern does not prevent you from building models with the Functional API, in which case you will use the class you created to instantiate the model with the <code>inputs</code> and <code>outputs</code>. Same goes for Sequential models, in which case you will subclass <a href="/api/models/sequential#sequential-class"><code>keras.Sequential</code></a> and override its <code>train_step</code> instead of <a href="/api/models/model#model-class"><code>keras.Model</code></a>.</p> <p>See the following guides:</p> <ul> <li><a href="/guides/custom_train_step_in_jax/">Writing a custom train step in JAX</a></li> <li><a href="/guides/custom_train_step_in_tensorflow/">Writing a custom train step in TensorFlow</a></li> <li><a href="/guides/custom_train_step_in_torch/">Writing a custom train step in PyTorch</a></li> </ul> <p><strong>2) Write a low-level custom training loop</strong></p> <p>This is a good option if you want to be in control of every last little detail &ndash; though it can be somewhat verbose.</p> <p>See the following guides:</p> <ul> <li><a href="/guides/writing_a_custom_training_loop_in_jax/">Writing a custom training loop in JAX</a></li> <li><a href="/guides/writing_a_custom_training_loop_in_tensorflow/">Writing a custom training loop in TensorFlow</a></li> <li><a href="/guides/writing_a_custom_training_loop_in_torch/">Writing a custom training loop in PyTorch</a></li> </ul> <hr /> <h3 id="whats-the-difference-between-model-methods-predict-and-call">What's the difference between <code>Model</code> methods <code>predict()</code> and <code>__call__()</code>?</h3> <p>Let's answer with an extract from <a href="https://www.manning.com/books/deep-learning-with-python-second-edition?a_aid=keras">Deep Learning with Python, Second Edition</a>:</p> <blockquote> <p>Both <code>y = model.predict(x)</code> and <code>y = model(x)</code> (where <code>x</code> is an array of input data) mean "run the model on <code>x</code> and retrieve the output <code>y</code>." Yet they aren't exactly the same thing.</p> <p><code>predict()</code> loops over the data in batches (in fact, you can specify the batch size via <code>predict(x, batch_size=64)</code>), and it extracts the NumPy value of the outputs. It's schematically equivalent to this:</p> </blockquote> <div class="codehilite"><pre><span></span><code><span class="k">def</span> <span class="nf">predict</span><span class="p">(</span><span class="n">x</span><span class="p">):</span> <span class="n">y_batches</span> <span class="o">=</span> <span class="p">[]</span> <span class="k">for</span> <span class="n">x_batch</span> <span class="ow">in</span> <span class="n">get_batches</span><span class="p">(</span><span class="n">x</span><span class="p">):</span> <span class="n">y_batch</span> <span class="o">=</span> <span class="n">model</span><span class="p">(</span><span class="n">x_batch</span><span class="p">)</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="n">y_batches</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">y_batch</span><span class="p">)</span> <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span><span class="n">y_batches</span><span class="p">)</span> </code></pre></div> <blockquote> <p>This means that <code>predict()</code> calls can scale to very large arrays. Meanwhile, <code>model(x)</code> happens in-memory and doesn't scale. On the other hand, <code>predict()</code> is not differentiable: you cannot retrieve its gradient if you call it in a <code>GradientTape</code> scope.</p> <p>You should use <code>model(x)</code> when you need to retrieve the gradients of the model call, and you should use <code>predict()</code> if you just need the output value. In other words, always use <code>predict()</code> unless you're in the middle of writing a low-level gradient descent loop (as we are now).</p> </blockquote> <hr /> <h2 id="modelingrelated-questions">Modeling-related questions</h2> <h3 id="how-can-i-obtain-the-output-of-an-intermediate-layer-feature-extraction">How can I obtain the output of an intermediate layer (feature extraction)?</h3> <p>In the Functional API and Sequential API, if a layer has been called exactly once, you can retrieve its output via <code>layer.output</code> and its input via <code>layer.input</code>. This enables you do quickly instantiate feature-extraction models, like this one:</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">keras</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> <span class="n">model</span> <span class="o">=</span> <span class="n">Sequential</span><span class="p">([</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">MaxPooling2D</span><span class="p">(</span><span class="mi">2</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Conv2D</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">&#39;relu&#39;</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">GlobalMaxPooling2D</span><span class="p">(),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">10</span><span class="p">),</span> <span class="p">])</span> <span class="n">extractor</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">model</span><span class="o">.</span><span class="n">inputs</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="p">[</span><span class="n">layer</span><span class="o">.</span><span class="n">output</span> <span class="k">for</span> <span class="n">layer</span> <span class="ow">in</span> <span class="n">model</span><span class="o">.</span><span class="n">layers</span><span class="p">])</span> <span class="n">features</span> <span class="o">=</span> <span class="n">extractor</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> </code></pre></div> <p>Naturally, this is not possible with models that are subclasses of <code>Model</code> that override <code>call</code>.</p> <p>Here's another example: instantiating a <code>Model</code> that returns the output of a specific named layer:</p> <div class="codehilite"><pre><span></span><code><span class="n">model</span> <span class="o">=</span> <span class="o">...</span> <span class="c1"># create the original model</span> <span class="n">layer_name</span> <span class="o">=</span> <span class="s1">&#39;my_layer&#39;</span> <span class="n">intermediate_layer_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Model</span><span class="p">(</span><span class="n">inputs</span><span class="o">=</span><span class="n">model</span><span class="o">.</span><span class="n">input</span><span class="p">,</span> <span class="n">outputs</span><span class="o">=</span><span class="n">model</span><span class="o">.</span><span class="n">get_layer</span><span class="p">(</span><span class="n">layer_name</span><span class="p">)</span><span class="o">.</span><span class="n">output</span><span class="p">)</span> <span class="n">intermediate_output</span> <span class="o">=</span> <span class="n">intermediate_layer_model</span><span class="p">(</span><span class="n">data</span><span class="p">)</span> </code></pre></div> <hr /> <h3 id="how-can-i-use-pretrained-models-in-keras">How can I use pre-trained models in Keras?</h3> <p>You could leverage the <a href="/api/applications/">models available in <code>keras.applications</code></a>, or the models available in <a href="/keras_cv/">KerasCV</a> and <a href="/keras_hub/">KerasHub</a>.</p> <hr /> <h3 id="how-can-i-use-stateful-rnns">How can I use stateful RNNs?</h3> <p>Making a RNN stateful means that the states for the samples of each batch will be reused as initial states for the samples in the next batch.</p> <p>When using stateful RNNs, it is therefore assumed that:</p> <ul> <li>all batches have the same number of samples</li> <li>If <code>x1</code> and <code>x2</code> are successive batches of samples, then <code>x2[i]</code> is the follow-up sequence to <code>x1[i]</code>, for every <code>i</code>.</li> </ul> <p>To use statefulness in RNNs, you need to:</p> <ul> <li>explicitly specify the batch size you are using, by passing a <code>batch_size</code> argument to the first layer in your model. E.g. <code>batch_size=32</code> for a 32-samples batch of sequences of 10 timesteps with 16 features per timestep.</li> <li>set <code>stateful=True</code> in your RNN layer(s).</li> <li>specify <code>shuffle=False</code> when calling <code>fit()</code>.</li> </ul> <p>To reset the states accumulated:</p> <ul> <li>use <code>model.reset_states()</code> to reset the states of all layers in the model</li> <li>use <code>layer.reset_states()</code> to reset the states of a specific stateful RNN layer</li> </ul> <p>Example:</p> <div class="codehilite"><pre><span></span><code><span class="kn">import</span> <span class="nn">keras</span> <span class="kn">from</span> <span class="nn">keras</span> <span class="kn">import</span> <span class="n">layers</span> <span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="n">x</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">random</span><span class="p">((</span><span class="mi">32</span><span class="p">,</span> <span class="mi">21</span><span class="p">,</span> <span class="mi">16</span><span class="p">))</span> <span class="c1"># this is our input data, of shape (32, 21, 16)</span> <span class="c1"># we will feed it to our model in sequences of length 10</span> <span class="n">model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">()</span> <span class="n">model</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">LSTM</span><span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="n">input_shape</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="mi">16</span><span class="p">),</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span> <span class="n">stateful</span><span class="o">=</span><span class="kc">True</span><span class="p">))</span> <span class="n">model</span><span class="o">.</span><span class="n">add</span><span class="p">(</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">16</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">&#39;softmax&#39;</span><span class="p">))</span> <span class="n">model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span><span class="n">optimizer</span><span class="o">=</span><span class="s1">&#39;rmsprop&#39;</span><span class="p">,</span> <span class="n">loss</span><span class="o">=</span><span class="s1">&#39;categorical_crossentropy&#39;</span><span class="p">)</span> <span class="c1"># we train the network to predict the 11th timestep given the first 10:</span> <span class="n">model</span><span class="o">.</span><span class="n">train_on_batch</span><span class="p">(</span><span class="n">x</span><span class="p">[:,</span> <span class="p">:</span><span class="mi">10</span><span class="p">,</span> <span class="p">:],</span> <span class="n">np</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x</span><span class="p">[:,</span> <span class="mi">10</span><span class="p">,</span> <span class="p">:],</span> <span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">16</span><span class="p">)))</span> <span class="c1"># the state of the network has changed. We can feed the follow-up sequences:</span> <span class="n">model</span><span class="o">.</span><span class="n">train_on_batch</span><span class="p">(</span><span class="n">x</span><span class="p">[:,</span> <span class="mi">10</span><span class="p">:</span><span class="mi">20</span><span class="p">,</span> <span class="p">:],</span> <span class="n">np</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">x</span><span class="p">[:,</span> <span class="mi">20</span><span class="p">,</span> <span class="p">:],</span> <span class="p">(</span><span class="mi">32</span><span class="p">,</span> <span class="mi">16</span><span class="p">)))</span> <span class="c1"># let&#39;s reset the states of the LSTM layer:</span> <span class="n">model</span><span class="o">.</span><span class="n">reset_states</span><span class="p">()</span> <span class="c1"># another way to do it in this case:</span> <span class="n">model</span><span class="o">.</span><span class="n">layers</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">.</span><span class="n">reset_states</span><span class="p">()</span> </code></pre></div> <p>Note that the methods <code>predict</code>, <code>fit</code>, <code>train_on_batch</code>, etc. will <em>all</em> update the states of the stateful layers in a model. This allows you to do not only stateful training, but also stateful prediction.</p> <hr /> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>

Pages: 1 2 3 4 5 6 7 8 9 10