CINXE.COM
The Functional API | TensorFlow Core
<!doctype html> <html lang="en" dir="ltr"> <head> <meta name="google-signin-client-id" content="157101835696-ooapojlodmuabs2do2vuhhnf90bccmoi.apps.googleusercontent.com"> <meta name="google-signin-scope" content="profile email https://www.googleapis.com/auth/developerprofiles https://www.googleapis.com/auth/developerprofiles.award"> <meta property="og:site_name" content="TensorFlow"> <meta property="og:type" content="website"><meta name="theme-color" content="#ff6f00"><meta charset="utf-8"> <meta content="IE=Edge" http-equiv="X-UA-Compatible"> <meta name="viewport" content="width=device-width, initial-scale=1"> <link rel="manifest" href="/_pwa/tensorflow/manifest.json" crossorigin="use-credentials"> <link rel="preconnect" href="//www.gstatic.com" crossorigin> <link rel="preconnect" href="//fonts.gstatic.com" crossorigin> <link rel="preconnect" href="//fonts.googleapis.com" crossorigin> <link rel="preconnect" href="//apis.google.com" crossorigin> <link rel="preconnect" href="//www.google-analytics.com" crossorigin><link rel="stylesheet" href="//fonts.googleapis.com/css?family=Google+Sans:400,500|Roboto:400,400italic,500,500italic,700,700italic|Roboto+Mono:400,500,700&display=swap"> <link rel="stylesheet" href="//fonts.googleapis.com/css2?family=Material+Icons&family=Material+Symbols+Outlined&display=block"><link rel="stylesheet" href="https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/css/app.css"> <link rel="shortcut icon" href="https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/images/favicon.png"> <link rel="apple-touch-icon" href="https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/images/apple-touch-icon-180x180.png"><link rel="canonical" href="https://www.tensorflow.org/guide/keras/functional_api"><link rel="search" 
type="application/opensearchdescription+xml" title="TensorFlow" href="https://www.tensorflow.org/s/opensearch.xml"> <link rel="alternate" hreflang="en" href="https://www.tensorflow.org/guide/keras/functional_api" /><link rel="alternate" hreflang="x-default" href="https://www.tensorflow.org/guide/keras/functional_api" /><title>The Functional API | TensorFlow Core</title> <meta property="og:title" content="The Functional API | TensorFlow Core"><meta name="description" content="Complete guide to the functional API."> <meta property="og:description" content="Complete guide to the functional API."><meta property="og:url" content="https://www.tensorflow.org/guide/keras/functional_api"><meta property="og:image" content="https://www.tensorflow.org/static/images/tf_logo_social.png"> <meta property="og:image:width" content="1200"> <meta property="og:image:height" content="675"><meta property="og:locale" content="en"><meta name="twitter:card" content="summary_large_image"><script type="application/ld+json"> { "@context": "https://schema.org", "@type": "Article", "headline": "The Functional API" } </script><script type="application/ld+json"> { "@context": "https://schema.org", "@type": "BreadcrumbList", "itemListElement": [{ "@type": "ListItem", "position": 1, "name": "TensorFlow Core", "item": "https://www.tensorflow.org/tutorials" },{ "@type": "ListItem", "position": 2, "name": "The Functional API", "item": "https://www.tensorflow.org/guide/keras/functional_api" }] } </script> <link rel="stylesheet" href="/extras.css"></head> <body class="" template="page" theme="tensorflow-theme" type="article" layout="docs" display-toc pending> <devsite-progress type="indeterminate" id="app-progress"></devsite-progress> <section class="devsite-wrapper"> <devsite-cookie-notification-bar></devsite-cookie-notification-bar><devsite-header role="banner"> <div class="devsite-header--inner nocontent"> <div class="devsite-top-logo-row-wrapper-wrapper"> <div class="devsite-top-logo-row-wrapper"> 
<div class="devsite-top-logo-row"> <button type="button" id="devsite-hamburger-menu" class="devsite-header-icon-button button-flat material-icons gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Navigation menu button" visually-hidden aria-label="Open menu"> </button> <div class="devsite-product-name-wrapper"> <a href="/" class="devsite-site-logo-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Site logo" track-type="globalNav" track-name="tensorFlow" track-metadata-position="nav" track-metadata-eventDetail="nav"> <picture> <img src="https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/images/lockup.svg" class="devsite-site-logo" alt="TensorFlow"> </picture> </a> <span class="devsite-product-name"> <ul class="devsite-breadcrumb-list" > <li class="devsite-breadcrumb-item "> </li> </ul> </span> </div> <div class="devsite-top-logo-row-middle"> <div class="devsite-header-upper-tabs"> <devsite-tabs class="upper-tabs"> <nav class="devsite-tabs-wrapper" aria-label="Upper tabs"> <tab > <a href="https://www.tensorflow.org/install" track-metadata-eventdetail="https://www.tensorflow.org/install" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - install" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Install" track-name="install" > Install </a> </tab> <tab class="devsite-dropdown devsite-active "> <a href="https://www.tensorflow.org/learn" track-metadata-eventdetail="https://www.tensorflow.org/learn" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - learn" track-metadata-module="primary nav" aria-label="Learn, selected" data-category="Site-Wide Custom Events" data-label="Tab: Learn" track-name="learn" > Learn </a> <a href="#" role="button" aria-haspopup="true" aria-expanded="false" aria-label="Dropdown menu for Learn" 
track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/learn" track-metadata-position="nav - learn" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Learn" track-name="learn" class="devsite-tabs-dropdown-toggle devsite-icon devsite-icon-arrow-drop-down"></a> <div class="devsite-tabs-dropdown" aria-label="submenu" hidden> <div class="devsite-tabs-dropdown-content"> <div class="devsite-tabs-dropdown-column tfo-menu-column-learn"> <ul class="devsite-tabs-dropdown-section "> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/learn" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/learn" track-metadata-position="nav - learn" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Introduction </div> <div class="devsite-nav-item-description"> New to TensorFlow? </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/tutorials" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/tutorials" track-metadata-position="nav - learn" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Tutorials </div> <div class="devsite-nav-item-description"> Learn how to use TensorFlow with end-to-end examples </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/guide" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/guide" track-metadata-position="nav - learn" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Guide </div> <div class="devsite-nav-item-description"> Learn framework concepts and components </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/resources/learn-ml" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/resources/learn-ml" track-metadata-position="nav - learn" track-metadata-module="tertiary nav" tooltip > <div 
class="devsite-nav-item-title"> Learn ML </div> <div class="devsite-nav-item-description"> Educational resources to master your path with TensorFlow </div> </a> </li> </ul> </div> </div> </div> </tab> <tab class="devsite-dropdown "> <a href="https://www.tensorflow.org/api" track-metadata-eventdetail="https://www.tensorflow.org/api" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - api" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: API" track-name="api" > API </a> <a href="#" role="button" aria-haspopup="true" aria-expanded="false" aria-label="Dropdown menu for API" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/api" track-metadata-position="nav - api" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: API" track-name="api" class="devsite-tabs-dropdown-toggle devsite-icon devsite-icon-arrow-drop-down"></a> <div class="devsite-tabs-dropdown" aria-label="submenu" hidden> <div class="devsite-tabs-dropdown-content"> <div class="devsite-tabs-dropdown-column "> <ul class="devsite-tabs-dropdown-section "> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/api/stable" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/api/stable" track-metadata-position="nav - api" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> TensorFlow (v2.16.1) </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/versions" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/versions" track-metadata-position="nav - api" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Versions… </div> </a> </li> </ul> </div> <div class="devsite-tabs-dropdown-column "> <ul class="devsite-tabs-dropdown-section "> <li class="devsite-nav-item"> <a href="https://js.tensorflow.org/api/latest/" track-type="nav" 
track-metadata-eventdetail="https://js.tensorflow.org/api/latest/" track-metadata-position="nav - api" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> TensorFlow.js </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/lite/api_docs" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/lite/api_docs" track-metadata-position="nav - api" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> TensorFlow Lite </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/tfx/api_docs" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/tfx/api_docs" track-metadata-position="nav - api" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> TFX </div> </a> </li> </ul> </div> </div> </div> </tab> <tab class="devsite-dropdown "> <a href="https://www.tensorflow.org/resources/models-datasets" track-metadata-eventdetail="https://www.tensorflow.org/resources/models-datasets" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - ecosystem" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Ecosystem" track-name="ecosystem" > Ecosystem </a> <a href="#" role="button" aria-haspopup="true" aria-expanded="false" aria-label="Dropdown menu for Ecosystem" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/resources/models-datasets" track-metadata-position="nav - ecosystem" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Ecosystem" track-name="ecosystem" class="devsite-tabs-dropdown-toggle devsite-icon devsite-icon-arrow-drop-down"></a> <div class="devsite-tabs-dropdown" aria-label="submenu" hidden> <div class="devsite-tabs-dropdown-content"> <div class="devsite-tabs-dropdown-column "> <ul class="devsite-tabs-dropdown-section "> <li class="devsite-nav-title" 
role="heading" tooltip>LIBRARIES</li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/js" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/js" track-metadata-position="nav - ecosystem" track-metadata-module="tertiary nav" track-metadata-module_headline="libraries" tooltip > <div class="devsite-nav-item-title"> TensorFlow.js </div> <div class="devsite-nav-item-description"> Develop web ML applications in JavaScript </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/lite" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/lite" track-metadata-position="nav - ecosystem" track-metadata-module="tertiary nav" track-metadata-module_headline="libraries" tooltip > <div class="devsite-nav-item-title"> TensorFlow Lite </div> <div class="devsite-nav-item-description"> Deploy ML on mobile, microcontrollers and other edge devices </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/tfx" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/tfx" track-metadata-position="nav - ecosystem" track-metadata-module="tertiary nav" track-metadata-module_headline="libraries" tooltip > <div class="devsite-nav-item-title"> TFX </div> <div class="devsite-nav-item-description"> Build production ML pipelines </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/resources/libraries-extensions" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/resources/libraries-extensions" track-metadata-position="nav - ecosystem" track-metadata-module="tertiary nav" track-metadata-module_headline="libraries" tooltip > <div class="devsite-nav-item-title"> All libraries </div> <div class="devsite-nav-item-description"> Create advanced models and extend TensorFlow </div> </a> </li> </ul> </div> <div class="devsite-tabs-dropdown-column "> <ul class="devsite-tabs-dropdown-section "> <li class="devsite-nav-title" 
role="heading" tooltip>RESOURCES</li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/resources/models-datasets" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/resources/models-datasets" track-metadata-position="nav - ecosystem" track-metadata-module="tertiary nav" track-metadata-module_headline="resources" tooltip > <div class="devsite-nav-item-title"> Models & datasets </div> <div class="devsite-nav-item-description"> Pre-trained models and datasets built by Google and the community </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/resources/tools" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/resources/tools" track-metadata-position="nav - ecosystem" track-metadata-module="tertiary nav" track-metadata-module_headline="resources" tooltip > <div class="devsite-nav-item-title"> Tools </div> <div class="devsite-nav-item-description"> Tools to support and accelerate TensorFlow workflows </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/responsible_ai" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/responsible_ai" track-metadata-position="nav - ecosystem" track-metadata-module="tertiary nav" track-metadata-module_headline="resources" tooltip > <div class="devsite-nav-item-title"> Responsible AI </div> <div class="devsite-nav-item-description"> Resources for every stage of the ML workflow </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/resources/recommendation-systems" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/resources/recommendation-systems" track-metadata-position="nav - ecosystem" track-metadata-module="tertiary nav" track-metadata-module_headline="resources" tooltip > <div class="devsite-nav-item-title"> Recommendation systems </div> <div class="devsite-nav-item-description"> Build recommendation systems with open source tools </div> </a> 
</li> </ul> </div> </div> </div> </tab> <tab class="devsite-dropdown "> <a href="https://www.tensorflow.org/community" track-metadata-eventdetail="https://www.tensorflow.org/community" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - community" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Community" track-name="community" > Community </a> <a href="#" role="button" aria-haspopup="true" aria-expanded="false" aria-label="Dropdown menu for Community" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/community" track-metadata-position="nav - community" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Community" track-name="community" class="devsite-tabs-dropdown-toggle devsite-icon devsite-icon-arrow-drop-down"></a> <div class="devsite-tabs-dropdown" aria-label="submenu" hidden> <div class="devsite-tabs-dropdown-content"> <div class="devsite-tabs-dropdown-column "> <ul class="devsite-tabs-dropdown-section "> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/community/groups" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/community/groups" track-metadata-position="nav - community" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Groups </div> <div class="devsite-nav-item-description"> User groups, interest groups and mailing lists </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/community/contribute" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/community/contribute" track-metadata-position="nav - community" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Contribute </div> <div class="devsite-nav-item-description"> Guide for contributing to code and documentation </div> </a> </li> <li class="devsite-nav-item"> <a 
href="https://blog.tensorflow.org/" track-type="nav" track-metadata-eventdetail="https://blog.tensorflow.org/" track-metadata-position="nav - community" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Blog </div> <div class="devsite-nav-item-description"> Stay up to date with all things TensorFlow </div> </a> </li> <li class="devsite-nav-item"> <a href="https://discuss.tensorflow.org" track-type="nav" track-metadata-eventdetail="https://discuss.tensorflow.org" track-metadata-position="nav - community" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Forum </div> <div class="devsite-nav-item-description"> Discussion platform for the TensorFlow community </div> </a> </li> </ul> </div> </div> </div> </tab> <tab class="devsite-dropdown "> <a href="https://www.tensorflow.org/about" track-metadata-eventdetail="https://www.tensorflow.org/about" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - why tensorflow" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Why TensorFlow" track-name="why tensorflow" > Why TensorFlow </a> <a href="#" role="button" aria-haspopup="true" aria-expanded="false" aria-label="Dropdown menu for Why TensorFlow" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/about" track-metadata-position="nav - why tensorflow" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Why TensorFlow" track-name="why tensorflow" class="devsite-tabs-dropdown-toggle devsite-icon devsite-icon-arrow-drop-down"></a> <div class="devsite-tabs-dropdown" aria-label="submenu" hidden> <div class="devsite-tabs-dropdown-content"> <div class="devsite-tabs-dropdown-column "> <ul class="devsite-tabs-dropdown-section "> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/about" track-type="nav" 
track-metadata-eventdetail="https://www.tensorflow.org/about" track-metadata-position="nav - why tensorflow" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> About </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/about/case-studies" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/about/case-studies" track-metadata-position="nav - why tensorflow" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Case studies </div> </a> </li> </ul> </div> </div> </div> </tab> </nav> </devsite-tabs> </div> <devsite-search enable-signin enable-search enable-suggestions enable-query-completion project-name="TensorFlow Core" tenant-name="TensorFlow" > <form class="devsite-search-form" action="https://www.tensorflow.org/s/results" method="GET"> <div class="devsite-search-container"> <button type="button" search-open class="devsite-search-button devsite-header-icon-button button-flat material-icons" aria-label="Open search"></button> <div class="devsite-searchbox"> <input aria-activedescendant="" aria-autocomplete="list" aria-label="Search" aria-expanded="false" aria-haspopup="listbox" autocomplete="off" class="devsite-search-field devsite-search-query" name="q" placeholder="Search" role="combobox" type="text" value="" > <div class="devsite-search-image material-icons" aria-hidden="true"> </div> <div class="devsite-search-shortcut-icon-container" aria-hidden="true"> <kbd class="devsite-search-shortcut-icon">/</kbd> </div> </div> </div> </form> <button type="button" search-close class="devsite-search-button devsite-header-icon-button button-flat material-icons" aria-label="Close search"></button> </devsite-search> </div> <devsite-language-selector> <ul role="presentation"> <li role="presentation"> <a role="menuitem" lang="en" >English</a> </li> <li role="presentation"> <a role="menuitem" lang="zh_cn" >中文 – 简体</a> </li> </ul> </devsite-language-selector> <a 
class="devsite-header-link devsite-top-button button gc-analytics-event" href="//github.com/tensorflow" data-category="Site-Wide Custom Events" data-label="Site header link" > GitHub </a> <devsite-user enable-profiles id="devsite-user"> <span class="button devsite-top-button" aria-hidden="true" visually-hidden>Sign in</span> </devsite-user> </div> </div> </div> <div class="devsite-collapsible-section "> <div class="devsite-header-background"> <div class="devsite-product-id-row" > <div class="devsite-product-description-row"> <ul class="devsite-breadcrumb-list" > <li class="devsite-breadcrumb-item "> <a href="https://www.tensorflow.org/tutorials" class="devsite-breadcrumb-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Lower Header" data-value="1" track-type="globalNav" track-name="breadcrumb" track-metadata-position="1" track-metadata-eventdetail="TensorFlow Core" > TensorFlow Core </a> </li> </ul> </div> </div> <div class="devsite-doc-set-nav-row"> <devsite-tabs class="lower-tabs"> <nav class="devsite-tabs-wrapper" aria-label="Lower tabs"> <tab > <a href="https://www.tensorflow.org/tutorials" track-metadata-eventdetail="https://www.tensorflow.org/tutorials" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - tutorials" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Tutorials" track-name="tutorials" > Tutorials </a> </tab> <tab class="devsite-active"> <a href="https://www.tensorflow.org/guide" track-metadata-eventdetail="https://www.tensorflow.org/guide" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - guide" track-metadata-module="primary nav" aria-label="Guide, selected" data-category="Site-Wide Custom Events" data-label="Tab: Guide" track-name="guide" > Guide </a> </tab> <tab > <a href="https://www.tensorflow.org/guide/migrate" track-metadata-eventdetail="https://www.tensorflow.org/guide/migrate" 
class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - migrate to tf2" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Migrate to TF2" track-name="migrate to tf2" > Migrate to TF2 </a> </tab> <tab > <a href="https://github.com/tensorflow/docs/tree/master/site/en/r1" track-metadata-eventdetail="https://github.com/tensorflow/docs/tree/master/site/en/r1" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - tf 1 ↗" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: TF 1 ↗" track-name="tf 1 ↗" > TF 1 ↗ </a> </tab> </nav> </devsite-tabs> </div> </div> </div> </div> </devsite-header> <devsite-book-nav scrollbars > <div class="devsite-book-nav-filter" > <span class="filter-list-icon material-icons" aria-hidden="true"></span> <input type="text" placeholder="Filter" aria-label="Type to filter" role="searchbox"> <span class="filter-clear-button hidden" data-title="Clear filter" aria-label="Clear filter" role="button" tabindex="0"></span> </div> <nav class="devsite-book-nav devsite-nav nocontent" aria-label="Side menu"> <div class="devsite-mobile-header"> <button type="button" id="devsite-close-nav" class="devsite-header-icon-button button-flat material-icons gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Close navigation" aria-label="Close navigation"> </button> <div class="devsite-product-name-wrapper"> <a href="/" class="devsite-site-logo-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Site logo" track-type="globalNav" track-name="tensorFlow" track-metadata-position="nav" track-metadata-eventDetail="nav"> <picture> <img src="https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/images/lockup.svg" class="devsite-site-logo" alt="TensorFlow"> </picture> </a> <span 
class="devsite-product-name"> <ul class="devsite-breadcrumb-list" > <li class="devsite-breadcrumb-item "> </li> </ul> </span> </div> </div> <div class="devsite-book-nav-wrapper"> <div class="devsite-mobile-nav-top"> <ul class="devsite-nav-list"> <li class="devsite-nav-item"> <a href="/install" class="devsite-nav-title gc-analytics-event devsite-nav-has-children " data-category="Site-Wide Custom Events" data-label="Tab: Install" track-name="install" data-category="Site-Wide Custom Events" data-label="Responsive Tab: Install" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Install </span> <span class="devsite-nav-icon material-icons" data-icon="forward" > </span> </a> </li> <li class="devsite-nav-item"> <a href="/learn" class="devsite-nav-title gc-analytics-event devsite-nav-active" data-category="Site-Wide Custom Events" data-label="Tab: Learn" track-name="learn" data-category="Site-Wide Custom Events" data-label="Responsive Tab: Learn" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Learn </span> </a> <ul class="devsite-nav-responsive-tabs devsite-nav-has-menu "> <li class="devsite-nav-item"> <span class="devsite-nav-title" tooltip data-category="Site-Wide Custom Events" data-label="Tab: Learn" track-name="learn" > <span class="devsite-nav-text" tooltip menu="Learn"> More </span> <span class="devsite-nav-icon material-icons" data-icon="forward" menu="Learn"> </span> </span> </li> </ul> <ul class="devsite-nav-responsive-tabs"> <li class="devsite-nav-item"> <a href="/tutorials" class="devsite-nav-title gc-analytics-event devsite-nav-has-children " data-category="Site-Wide Custom Events" data-label="Tab: Tutorials" track-name="tutorials" data-category="Site-Wide Custom Events" data-label="Responsive Tab: Tutorials" track-type="globalNav" track-metadata-eventDetail="globalMenu" 
track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Tutorials </span> <span class="devsite-nav-icon material-icons" data-icon="forward" > </span> </a> </li> <li class="devsite-nav-item"> <a href="/guide" class="devsite-nav-title gc-analytics-event devsite-nav-has-children devsite-nav-active" data-category="Site-Wide Custom Events" data-label="Tab: Guide" track-name="guide" data-category="Site-Wide Custom Events" data-label="Responsive Tab: Guide" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip menu="_book"> Guide </span> <span class="devsite-nav-icon material-icons" data-icon="forward" menu="_book"> </span> </a> </li> <li class="devsite-nav-item"> <a href="/guide/migrate" class="devsite-nav-title gc-analytics-event devsite-nav-has-children " data-category="Site-Wide Custom Events" data-label="Tab: Migrate to TF2" track-name="migrate to tf2" data-category="Site-Wide Custom Events" data-label="Responsive Tab: Migrate to TF2" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Migrate to TF2 </span> <span class="devsite-nav-icon material-icons" data-icon="forward" > </span> </a> </li> <li class="devsite-nav-item"> <a href="https://github.com/tensorflow/docs/tree/master/site/en/r1" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Tab: TF 1 ↗" track-name="tf 1 ↗" data-category="Site-Wide Custom Events" data-label="Responsive Tab: TF 1 ↗" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > TF 1 ↗ </span> </a> </li> </ul> </li> <li class="devsite-nav-item"> <a href="/api" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Tab: API" track-name="api" data-category="Site-Wide Custom Events" data-label="Responsive 
Tab: API" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > API </span> </a> <ul class="devsite-nav-responsive-tabs devsite-nav-has-menu "> <li class="devsite-nav-item"> <span class="devsite-nav-title" tooltip data-category="Site-Wide Custom Events" data-label="Tab: API" track-name="api" > <span class="devsite-nav-text" tooltip menu="API"> More </span> <span class="devsite-nav-icon material-icons" data-icon="forward" menu="API"> </span> </span> </li> </ul> </li> <li class="devsite-nav-item"> <a href="/resources/models-datasets" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Tab: Ecosystem" track-name="ecosystem" data-category="Site-Wide Custom Events" data-label="Responsive Tab: Ecosystem" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Ecosystem </span> </a> <ul class="devsite-nav-responsive-tabs devsite-nav-has-menu "> <li class="devsite-nav-item"> <span class="devsite-nav-title" tooltip data-category="Site-Wide Custom Events" data-label="Tab: Ecosystem" track-name="ecosystem" > <span class="devsite-nav-text" tooltip menu="Ecosystem"> More </span> <span class="devsite-nav-icon material-icons" data-icon="forward" menu="Ecosystem"> </span> </span> </li> </ul> </li> <li class="devsite-nav-item"> <a href="/community" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Tab: Community" track-name="community" data-category="Site-Wide Custom Events" data-label="Responsive Tab: Community" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Community </span> </a> <ul class="devsite-nav-responsive-tabs devsite-nav-has-menu "> <li class="devsite-nav-item"> <span class="devsite-nav-title" tooltip data-category="Site-Wide Custom 
Events" data-label="Tab: Community" track-name="community" > <span class="devsite-nav-text" tooltip menu="Community"> More </span> <span class="devsite-nav-icon material-icons" data-icon="forward" menu="Community"> </span> </span> </li> </ul> </li> <li class="devsite-nav-item"> <a href="/about" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Tab: Why TensorFlow" track-name="why tensorflow" data-category="Site-Wide Custom Events" data-label="Responsive Tab: Why TensorFlow" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Why TensorFlow </span> </a> <ul class="devsite-nav-responsive-tabs devsite-nav-has-menu "> <li class="devsite-nav-item"> <span class="devsite-nav-title" tooltip data-category="Site-Wide Custom Events" data-label="Tab: Why TensorFlow" track-name="why tensorflow" > <span class="devsite-nav-text" tooltip menu="Why TensorFlow"> More </span> <span class="devsite-nav-icon material-icons" data-icon="forward" menu="Why TensorFlow"> </span> </span> </li> </ul> </li> <li class="devsite-nav-item"> <a href="//github.com/tensorflow" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: GitHub" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > GitHub </span> </a> </li> </ul> </div> <div class="devsite-mobile-nav-bottom"> <ul class="devsite-nav-list" menu="_book"> <li class="devsite-nav-item"><a href="/guide" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide" ><span class="devsite-nav-text" tooltip>TensorFlow guide</span></a></li> <li class="devsite-nav-item devsite-nav-heading"><div class="devsite-nav-title devsite-nav-title-no-path"> 
<span class="devsite-nav-text" tooltip>TensorFlow basics</span> </div></li> <li class="devsite-nav-item"><a href="/guide/basics" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/basics" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/basics" ><span class="devsite-nav-text" tooltip>Overview</span></a></li> <li class="devsite-nav-item"><a href="/guide/tensor" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/tensor" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/tensor" ><span class="devsite-nav-text" tooltip>Tensors</span></a></li> <li class="devsite-nav-item"><a href="/guide/variable" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/variable" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/variable" ><span class="devsite-nav-text" tooltip>Variables</span></a></li> <li class="devsite-nav-item"><a href="/guide/autodiff" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/autodiff" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/autodiff" ><span class="devsite-nav-text" tooltip>Automatic differentiation</span></a></li> <li class="devsite-nav-item"><a href="/guide/intro_to_graphs" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/intro_to_graphs" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/intro_to_graphs" ><span class="devsite-nav-text" tooltip>Graphs and functions</span></a></li> <li class="devsite-nav-item"><a href="/guide/intro_to_modules" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book 
nav link, pathname: /guide/intro_to_modules" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/intro_to_modules" ><span class="devsite-nav-text" tooltip>Modules, layers, and models</span></a></li> <li class="devsite-nav-item"><a href="/guide/basic_training_loops" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/basic_training_loops" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/basic_training_loops" ><span class="devsite-nav-text" tooltip>Training loops</span></a></li> <li class="devsite-nav-item devsite-nav-heading"><div class="devsite-nav-title devsite-nav-title-no-path"> <span class="devsite-nav-text" tooltip>Keras</span> </div></li> <li class="devsite-nav-item"><a href="/guide/keras" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/keras" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/keras" ><span class="devsite-nav-text" tooltip>Overview</span></a></li> <li class="devsite-nav-item"><a href="/guide/keras/sequential_model" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/keras/sequential_model" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/keras/sequential_model" ><span class="devsite-nav-text" tooltip>The Sequential model</span></a></li> <li class="devsite-nav-item"><a href="/guide/keras/functional_api" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/keras/functional_api" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/keras/functional_api" ><span class="devsite-nav-text" tooltip>The Functional API</span></a></li> <li class="devsite-nav-item"><a href="/guide/keras/training_with_built_in_methods" 
class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/keras/training_with_built_in_methods" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/keras/training_with_built_in_methods" ><span class="devsite-nav-text" tooltip>Training & evaluation with the built-in methods</span></a></li> <li class="devsite-nav-item"><a href="/guide/keras/making_new_layers_and_models_via_subclassing" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/keras/making_new_layers_and_models_via_subclassing" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/keras/making_new_layers_and_models_via_subclassing" ><span class="devsite-nav-text" tooltip>Making new layers and models via subclassing</span></a></li> <li class="devsite-nav-item"><a href="/guide/keras/serialization_and_saving" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/keras/serialization_and_saving" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/keras/serialization_and_saving" ><span class="devsite-nav-text" tooltip>Serialization and saving</span></a></li> <li class="devsite-nav-item"><a href="/guide/keras/customizing_saving_and_serialization" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/keras/customizing_saving_and_serialization" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/keras/customizing_saving_and_serialization" ><span class="devsite-nav-text" tooltip>Customizing Saving</span></a></li> <li class="devsite-nav-item"><a href="/guide/keras/preprocessing_layers" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: 
/guide/keras/preprocessing_layers" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/keras/preprocessing_layers" ><span class="devsite-nav-text" tooltip>Working with preprocessing layers</span></a></li> <li class="devsite-nav-item"><a href="/guide/keras/customizing_what_happens_in_fit" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/keras/customizing_what_happens_in_fit" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/keras/customizing_what_happens_in_fit" ><span class="devsite-nav-text" tooltip>Customizing what happens in fit()</span></a></li> <li class="devsite-nav-item"><a href="/guide/keras/writing_a_training_loop_from_scratch" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/keras/writing_a_training_loop_from_scratch" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/keras/writing_a_training_loop_from_scratch" ><span class="devsite-nav-text" tooltip>Writing a training loop from scratch</span></a></li> <li class="devsite-nav-item"><a href="/guide/keras/working_with_rnns" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/keras/working_with_rnns" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/keras/working_with_rnns" ><span class="devsite-nav-text" tooltip>Working with RNNs</span></a></li> <li class="devsite-nav-item"><a href="/guide/keras/understanding_masking_and_padding" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/keras/understanding_masking_and_padding" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/keras/understanding_masking_and_padding" ><span class="devsite-nav-text" tooltip>Understanding masking & 
padding</span></a></li> <li class="devsite-nav-item"><a href="/guide/keras/writing_your_own_callbacks" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/keras/writing_your_own_callbacks" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/keras/writing_your_own_callbacks" ><span class="devsite-nav-text" tooltip>Writing your own callbacks</span></a></li> <li class="devsite-nav-item"><a href="/guide/keras/transfer_learning" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/keras/transfer_learning" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/keras/transfer_learning" ><span class="devsite-nav-text" tooltip>Transfer learning & fine-tuning</span></a></li> <li class="devsite-nav-item"><a href="/guide/keras/distributed_training" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/keras/distributed_training" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/keras/distributed_training" ><span class="devsite-nav-text" tooltip>Multi-GPU and distributed training</span></a></li> <li class="devsite-nav-item devsite-nav-heading devsite-nav-new"><div class="devsite-nav-title devsite-nav-title-no-path"> <span class="devsite-nav-text" tooltip>Build with Core</span><span class="devsite-nav-icon material-icons" data-icon="new" data-title="New!" 
aria-hidden="true"></span> </div></li> <li class="devsite-nav-item"><a href="/guide/core" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/core" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/core" ><span class="devsite-nav-text" tooltip>Overview</span></a></li> <li class="devsite-nav-item"><a href="/guide/core/quickstart_core" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/core/quickstart_core" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/core/quickstart_core" ><span class="devsite-nav-text" tooltip>Quickstart for Core</span></a></li> <li class="devsite-nav-item"><a href="/guide/core/logistic_regression_core" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/core/logistic_regression_core" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/core/logistic_regression_core" ><span class="devsite-nav-text" tooltip>Logistic regression</span></a></li> <li class="devsite-nav-item"><a href="/guide/core/mlp_core" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/core/mlp_core" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/core/mlp_core" ><span class="devsite-nav-text" tooltip>Multilayer perceptrons</span></a></li> <li class="devsite-nav-item"><a href="/guide/core/matrix_core" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/core/matrix_core" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/core/matrix_core" ><span class="devsite-nav-text" tooltip>Matrix approximation</span></a></li> <li class="devsite-nav-item"><a href="/guide/core/optimizers_core" 
class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/core/optimizers_core" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/core/optimizers_core" ><span class="devsite-nav-text" tooltip>Custom optimizers</span></a></li> <li class="devsite-nav-item devsite-nav-experimental"><a href="/guide/core/distribution" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/core/distribution" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/core/distribution" ><span class="devsite-nav-text" tooltip>DTensor with Core APIs</span><span class="devsite-nav-icon material-icons" data-icon="experimental" data-title="Experimental!" aria-hidden="true"></span></a></li> <li class="devsite-nav-item devsite-nav-heading"><div class="devsite-nav-title devsite-nav-title-no-path"> <span class="devsite-nav-text" tooltip>TensorFlow in depth</span> </div></li> <li class="devsite-nav-item"><a href="/guide/tensor_slicing" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/tensor_slicing" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/tensor_slicing" ><span class="devsite-nav-text" tooltip>Tensor slicing</span></a></li> <li class="devsite-nav-item"><a href="/guide/advanced_autodiff" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/advanced_autodiff" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/advanced_autodiff" ><span class="devsite-nav-text" tooltip>Advanced autodiff</span></a></li> <li class="devsite-nav-item"><a href="/guide/ragged_tensor" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/ragged_tensor" 
track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/ragged_tensor" ><span class="devsite-nav-text" tooltip>Ragged tensor</span></a></li> <li class="devsite-nav-item"><a href="/guide/sparse_tensor" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/sparse_tensor" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/sparse_tensor" ><span class="devsite-nav-text" tooltip>Sparse tensor</span></a></li> <li class="devsite-nav-item"><a href="/guide/random_numbers" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/random_numbers" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/random_numbers" ><span class="devsite-nav-text" tooltip>Random number generation</span></a></li> <li class="devsite-nav-item devsite-nav-experimental"><a href="/guide/tf_numpy" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/tf_numpy" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/tf_numpy" ><span class="devsite-nav-text" tooltip>NumPy API</span><span class="devsite-nav-icon material-icons" data-icon="experimental" data-title="Experimental!" 
aria-hidden="true"></span></a></li> <li class="devsite-nav-item devsite-nav-nightly"><a href="/guide/tf_numpy_type_promotion" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/tf_numpy_type_promotion" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/tf_numpy_type_promotion" ><span class="devsite-nav-text" tooltip>NumPy API Type Promotion</span><span class="devsite-nav-icon material-icons" data-icon="nightly" data-title="Nightly build only" aria-hidden="true"></span></a></li> <li class="devsite-nav-item devsite-nav-experimental"><a href="/guide/dtensor_overview" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/dtensor_overview" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/dtensor_overview" ><span class="devsite-nav-text" tooltip>DTensor concepts</span><span class="devsite-nav-icon material-icons" data-icon="experimental" data-title="Experimental!" 
aria-hidden="true"></span></a></li> <li class="devsite-nav-item"><a href="/guide/effective_tf2" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/effective_tf2" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/effective_tf2" ><span class="devsite-nav-text" tooltip>Thinking in TensorFlow 2</span></a></li> <li class="devsite-nav-item devsite-nav-heading"><div class="devsite-nav-title devsite-nav-title-no-path"> <span class="devsite-nav-text" tooltip>Customization</span> </div></li> <li class="devsite-nav-item"><a href="/guide/create_op" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/create_op" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/create_op" ><span class="devsite-nav-text" tooltip>Create an op</span></a></li> <li class="devsite-nav-item devsite-nav-experimental"><a href="/guide/extension_type" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/extension_type" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/extension_type" ><span class="devsite-nav-text" tooltip>Extension types</span><span class="devsite-nav-icon material-icons" data-icon="experimental" data-title="Experimental!" 
aria-hidden="true"></span></a></li> <li class="devsite-nav-item devsite-nav-heading"><div class="devsite-nav-title devsite-nav-title-no-path"> <span class="devsite-nav-text" tooltip>Data input pipelines</span> </div></li> <li class="devsite-nav-item"><a href="/guide/data" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/data" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/data" ><span class="devsite-nav-text" tooltip>tf.data</span></a></li> <li class="devsite-nav-item"><a href="/guide/data_performance" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/data_performance" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/data_performance" ><span class="devsite-nav-text" tooltip>Optimize pipeline performance</span></a></li> <li class="devsite-nav-item"><a href="/guide/data_performance_analysis" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/data_performance_analysis" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/data_performance_analysis" ><span class="devsite-nav-text" tooltip>Analyze pipeline performance</span></a></li> <li class="devsite-nav-item devsite-nav-heading"><div class="devsite-nav-title devsite-nav-title-no-path"> <span class="devsite-nav-text" tooltip>Import and export</span> </div></li> <li class="devsite-nav-item"><a href="/guide/checkpoint" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/checkpoint" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/checkpoint" ><span class="devsite-nav-text" tooltip>Checkpoint</span></a></li> <li class="devsite-nav-item"><a href="/guide/saved_model" class="devsite-nav-title gc-analytics-event" 
data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/saved_model" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/saved_model" ><span class="devsite-nav-text" tooltip>SavedModel</span></a></li> <li class="devsite-nav-item devsite-nav-new"><a href="/guide/jax2tf" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/jax2tf" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/jax2tf" ><span class="devsite-nav-text" tooltip>Import a JAX model using JAX2TF</span><span class="devsite-nav-icon material-icons" data-icon="new" data-title="New!" aria-hidden="true"></span></a></li> <li class="devsite-nav-item devsite-nav-heading"><div class="devsite-nav-title devsite-nav-title-no-path"> <span class="devsite-nav-text" tooltip>Accelerators</span> </div></li> <li class="devsite-nav-item"><a href="/guide/distributed_training" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/distributed_training" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/distributed_training" ><span class="devsite-nav-text" tooltip>Distributed training</span></a></li> <li class="devsite-nav-item"><a href="/guide/gpu" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/gpu" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/gpu" ><span class="devsite-nav-text" tooltip>GPU</span></a></li> <li class="devsite-nav-item"><a href="/guide/tpu" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/tpu" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/tpu" ><span class="devsite-nav-text" tooltip>TPU</span></a></li> <li class="devsite-nav-item 
devsite-nav-heading"><div class="devsite-nav-title devsite-nav-title-no-path"> <span class="devsite-nav-text" tooltip>Performance</span> </div></li> <li class="devsite-nav-item"><a href="/guide/function" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/function" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/function" ><span class="devsite-nav-text" tooltip>Better performance with tf.function</span></a></li> <li class="devsite-nav-item"><a href="/guide/profiler" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/profiler" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/profiler" ><span class="devsite-nav-text" tooltip>Profile TensorFlow performance</span></a></li> <li class="devsite-nav-item"><a href="/guide/gpu_performance_analysis" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/gpu_performance_analysis" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/gpu_performance_analysis" ><span class="devsite-nav-text" tooltip>Optimize GPU Performance</span></a></li> <li class="devsite-nav-item"><a href="/guide/graph_optimization" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/graph_optimization" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/graph_optimization" ><span class="devsite-nav-text" tooltip>Graph optimization</span></a></li> <li class="devsite-nav-item"><a href="/guide/mixed_precision" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/mixed_precision" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/mixed_precision" ><span 
class="devsite-nav-text" tooltip>Mixed precision</span></a></li> <li class="devsite-nav-item devsite-nav-heading devsite-nav-new"><div class="devsite-nav-title devsite-nav-title-no-path"> <span class="devsite-nav-text" tooltip>Model Garden</span><span class="devsite-nav-icon material-icons" data-icon="new" data-title="New!" aria-hidden="true"></span> </div></li> <li class="devsite-nav-item"><a href="/tfmodels" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /tfmodels" track-type="bookNav" track-name="click" track-metadata-eventdetail="/tfmodels" ><span class="devsite-nav-text" tooltip>Overview</span></a></li> <li class="devsite-nav-item"><a href="/tfmodels/orbit" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /tfmodels/orbit" track-type="bookNav" track-name="click" track-metadata-eventdetail="/tfmodels/orbit" ><span class="devsite-nav-text" tooltip>Training with Orbit</span></a></li> <li class="devsite-nav-item devsite-nav-external"><a href="/tfmodels/nlp" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /tfmodels/nlp" track-type="bookNav" track-name="click" track-metadata-eventdetail="/tfmodels/nlp" ><span class="devsite-nav-text" tooltip>TFModels - NLP</span><span class="devsite-nav-icon material-icons" data-icon="external" data-title="External" aria-hidden="true"></span></a></li> <li class="devsite-nav-item"><a href="/tfmodels/vision/image_classification" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /tfmodels/vision/image_classification" track-type="bookNav" track-name="click" track-metadata-eventdetail="/tfmodels/vision/image_classification" ><span class="devsite-nav-text" tooltip>Example: Image classification</span></a></li> <li class="devsite-nav-item"><a 
href="/tfmodels/vision/object_detection" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /tfmodels/vision/object_detection" track-type="bookNav" track-name="click" track-metadata-eventdetail="/tfmodels/vision/object_detection" ><span class="devsite-nav-text" tooltip>Example: Object Detection</span></a></li> <li class="devsite-nav-item"><a href="/tfmodels/vision/semantic_segmentation" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /tfmodels/vision/semantic_segmentation" track-type="bookNav" track-name="click" track-metadata-eventdetail="/tfmodels/vision/semantic_segmentation" ><span class="devsite-nav-text" tooltip>Example: Semantic Segmentation</span></a></li> <li class="devsite-nav-item"><a href="/tfmodels/vision/instance_segmentation" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /tfmodels/vision/instance_segmentation" track-type="bookNav" track-name="click" track-metadata-eventdetail="/tfmodels/vision/instance_segmentation" ><span class="devsite-nav-text" tooltip>Example: Instance Segmentation</span></a></li> <li class="devsite-nav-item devsite-nav-heading devsite-nav-deprecated"><div class="devsite-nav-title devsite-nav-title-no-path"> <span class="devsite-nav-text" tooltip>Estimators</span><span class="devsite-nav-icon material-icons" data-icon="deprecated" data-title="Deprecated" aria-hidden="true"></span> </div></li> <li class="devsite-nav-item"><a href="/guide/estimator" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/estimator" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/estimator" ><span class="devsite-nav-text" tooltip>Estimator overview</span></a></li> <li class="devsite-nav-item devsite-nav-heading"><div 
class="devsite-nav-title devsite-nav-title-no-path"> <span class="devsite-nav-text" tooltip>Appendix</span> </div></li> <li class="devsite-nav-item"><a href="/guide/versions" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /guide/versions" track-type="bookNav" track-name="click" track-metadata-eventdetail="/guide/versions" ><span class="devsite-nav-text" tooltip>Version compatibility</span></a></li> </ul> <ul class="devsite-nav-list" menu="Learn" aria-label="Side menu" hidden> <li class="devsite-nav-item"> <a href="/learn" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Introduction" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Introduction </span> </a> </li> <li class="devsite-nav-item"> <a href="/tutorials" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Tutorials" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Tutorials </span> </a> </li> <li class="devsite-nav-item"> <a href="/guide" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Guide" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Guide </span> </a> </li> <li class="devsite-nav-item"> <a href="/resources/learn-ml" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Learn ML" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Learn ML </span> </a> </li> </ul> <ul class="devsite-nav-list" menu="API" aria-label="Side menu" hidden> <li 
class="devsite-nav-item"> <a href="/api/stable" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: TensorFlow (v2.16.1)" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > TensorFlow (v2.16.1) </span> </a> </li> <li class="devsite-nav-item"> <a href="/versions" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Versions…" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Versions… </span> </a> </li> <li class="devsite-nav-item"> <a href="https://js.tensorflow.org/api/latest/" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: TensorFlow.js" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > TensorFlow.js </span> </a> </li> <li class="devsite-nav-item"> <a href="/lite/api_docs" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: TensorFlow Lite" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > TensorFlow Lite </span> </a> </li> <li class="devsite-nav-item"> <a href="/tfx/api_docs" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: TFX" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > TFX </span> </a> </li> </ul> <ul class="devsite-nav-list" menu="Ecosystem" aria-label="Side menu" hidden> <li class="devsite-nav-item devsite-nav-heading"> <span class="devsite-nav-title" tooltip > <span class="devsite-nav-text" tooltip > LIBRARIES </span> </span> </li> 
<li class="devsite-nav-item"> <a href="/js" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: TensorFlow.js" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > TensorFlow.js </span> </a> </li> <li class="devsite-nav-item"> <a href="/lite" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: TensorFlow Lite" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > TensorFlow Lite </span> </a> </li> <li class="devsite-nav-item"> <a href="/tfx" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: TFX" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > TFX </span> </a> </li> <li class="devsite-nav-item"> <a href="/resources/libraries-extensions" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: All libraries" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > All libraries </span> </a> </li> <li class="devsite-nav-item devsite-nav-heading"> <span class="devsite-nav-title" tooltip > <span class="devsite-nav-text" tooltip > RESOURCES </span> </span> </li> <li class="devsite-nav-item"> <a href="/resources/models-datasets" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Models & datasets" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Models & datasets </span> </a> </li> <li class="devsite-nav-item"> <a href="/resources/tools" class="devsite-nav-title 
gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Tools" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Tools </span> </a> </li> <li class="devsite-nav-item"> <a href="/responsible_ai" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Responsible AI" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Responsible AI </span> </a> </li> <li class="devsite-nav-item"> <a href="/resources/recommendation-systems" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Recommendation systems" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Recommendation systems </span> </a> </li> </ul> <ul class="devsite-nav-list" menu="Community" aria-label="Side menu" hidden> <li class="devsite-nav-item"> <a href="/community/groups" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Groups" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Groups </span> </a> </li> <li class="devsite-nav-item"> <a href="/community/contribute" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Contribute" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Contribute </span> </a> </li> <li class="devsite-nav-item"> <a href="https://blog.tensorflow.org/" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Blog" track-type="navMenu" 
track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Blog </span> </a> </li> <li class="devsite-nav-item"> <a href="https://discuss.tensorflow.org" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Forum" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Forum </span> </a> </li> </ul> <ul class="devsite-nav-list" menu="Why TensorFlow" aria-label="Side menu" hidden> <li class="devsite-nav-item"> <a href="/about" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: About" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > About </span> </a> </li> <li class="devsite-nav-item"> <a href="/about/case-studies" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Case studies" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Case studies </span> </a> </li> </ul> </div> </div> </nav> </devsite-book-nav> <section id="gc-wrapper"> <main role="main" class="devsite-main-content" has-book-nav has-sidebar > <div class="devsite-sidebar"> <div class="devsite-sidebar-content"> <devsite-toc class="devsite-nav" role="navigation" aria-label="On this page" depth="2" scrollbars ></devsite-toc> <devsite-recommendations-sidebar class="nocontent devsite-nav"> </devsite-recommendations-sidebar> </div> </div> <devsite-content> <article class="devsite-article"><style> /* Styles inlined from /site-assets/css/style.css */ /* override theme */ table img { max-width: 100%; } /* .devsite-terminal virtualenv prompt */ .tfo-terminal-venv::before { content: "(venv) $ " !important; } /* .devsite-terminal root prompt 
*/ .tfo-terminal-root::before { content: "# " !important; } /* Used in links for type annotations in function/method signatures */ .tfo-signature-link a, .tfo-signature-link a:visited, .tfo-signature-link a:hover, .tfo-signature-link a:focus, .tfo-signature-link a:hover *, .tfo-signature-link a:focus * { text-decoration: none !important; } .tfo-signature-link a, .tfo-signature-link a:visited { border-bottom: 1px dotted #1a73e8; } .tfo-signature-link a:focus { border-bottom-style: solid; } /* .devsite-terminal Windows prompt */ .tfo-terminal-windows::before { content: "C:\\> " !important; } /* .devsite-terminal Windows prompt w/ virtualenv */ .tfo-terminal-windows-venv::before { content: "(venv) C:\\> " !important; } .tfo-diff-green-one-level + * { background: rgba(175, 245, 162, .6) !important; } .tfo-diff-green + * > * { background: rgba(175, 245, 162, .6) !important; } .tfo-diff-green-list + ul > li:first-of-type { background: rgba(175, 245, 162, .6) !important; } .tfo-diff-red-one-level + * { background: rgba(255, 230, 230, .6) !important; text-decoration: line-through !important; } .tfo-diff-red + * > * { background: rgba(255, 230, 230, .6) !important; text-decoration: line-through !important; } .tfo-diff-red-list + ul > li:first-of-type { background: rgba(255, 230, 230, .6) !important; text-decoration: line-through !important; } devsite-code .tfo-notebook-code-cell-output { max-height: 300px; overflow: auto; background: rgba(255, 247, 237, 1); /* orange bg to distinguish from input code cells */ } devsite-code .tfo-notebook-code-cell-output + .devsite-code-buttons-container button { background: rgba(255, 247, 237, .7); /* orange bg to distinguish from input code cells */ } devsite-code[dark-code] .tfo-notebook-code-cell-output { background: rgba(64, 78, 103, 1); /* medium slate */ } devsite-code[dark-code] .tfo-notebook-code-cell-output + .devsite-code-buttons-container button { background: rgba(64, 78, 103, .7); /* medium slate */ } /* override default table 
styles for notebook buttons */ .devsite-table-wrapper .tfo-notebook-buttons { display: inline-block; margin-left: 3px; width: auto; } .tfo-notebook-buttons td { padding-left: 0; padding-right: 20px; } .tfo-notebook-buttons a, .tfo-notebook-buttons :link, .tfo-notebook-buttons :visited { border-radius: 8px; box-shadow: 0 1px 2px 0 rgba(60, 64, 67, .3), 0 1px 3px 1px rgba(60, 64, 67, .15); color: #202124; padding: 12px 17px; transition: box-shadow 0.2s; } .tfo-notebook-buttons a:hover, .tfo-notebook-buttons a:focus { box-shadow: 0 1px 2px 0 rgba(60, 64, 67, .3), 0 2px 6px 2px rgba(60, 64, 67, .15); } .tfo-notebook-buttons tr { background: 0; border: 0; } /* on rendered notebook page, remove link to webpage since we're already here */ .tfo-notebook-buttons:not(.tfo-api) td:first-child { display: none; } .tfo-notebook-buttons td > a { -webkit-box-align: center; -ms-flex-align: center; align-items: center; display: -webkit-box; display: -ms-flexbox; display: flex; } .tfo-notebook-buttons td > a > img { margin-right: 8px; } /* landing pages */ .tfo-landing-row-item-inset-white { background-color: #fff; padding: 32px; } .tfo-landing-row-item-inset-white ol, .tfo-landing-row-item-inset-white ul { padding-left: 20px; } /* colab callout button */ .colab-callout-row devsite-code { border-radius: 8px 8px 0 0; box-shadow: none; } .colab-callout-footer { background: #e3e4e7; border-radius: 0 0 8px 8px; color: #37474f; padding: 20px; } .colab-callout-row devsite-code[dark-code] + .colab-callout-footer { background: #3f4f66; } .colab-callout-footer > .button { margin-top: 4px; color: #ff5c00; } .colab-callout-footer > a > span { vertical-align: middle; color: #37474f; padding-left: 10px; font-size: 14px; } .colab-callout-row devsite-code[dark-code] + .colab-callout-footer > a > span { color: #fff; } a.colab-button { background: rgba(255, 255, 255, .75); border: solid 1px rgba(0, 0, 0, .08); border-bottom-color: rgba(0, 0, 0, .15); border-radius: 4px; color: #aaa; display: 
inline-block; font-size: 11px !important; font-weight: 300; line-height: 16px; padding: 4px 8px; text-decoration: none; text-transform: uppercase; } a.colab-button:hover { background: white; border-color: rgba(0, 0, 0, .2); color: #666; } a.colab-button span { background: url(/images/colab_logo_button.svg) no-repeat 1px 1px / 20px; border-radius: 4px; display: inline-block; padding-left: 24px; text-decoration: none; } @media screen and (max-width: 600px) { .tfo-notebook-buttons td { display: block; } } /* guide and tutorials landing page cards and sections */ .tfo-landing-page-card { padding: 16px; box-shadow: 0 0 36px rgba(0,0,0,0.1); border-radius: 10px; } /* Page section headings */ .tfo-landing-page-heading h2, h2.tfo-landing-page-heading { font-family: "Google Sans", sans-serif; color: #425066; font-size: 30px; font-weight: 700; line-height: 40px; } /* Item title headings */ .tfo-landing-page-heading h3, h3.tfo-landing-page-heading, .tfo-landing-page-card h3, h3.tfo-landing-page-card { font-family: "Google Sans", sans-serif; color: #425066; font-size: 20px; font-weight: 500; line-height: 26px; } /* List of tutorials notebooks for subsites */ .tfo-landing-page-resources-ul { padding-left: 15px } .tfo-landing-page-resources-ul > li { margin: 6px 0; } /* Temporary fix to hide product description in header on landing pages */ devsite-header .devsite-product-description { display: none; } </style> <div class="devsite-article-meta nocontent" role="navigation"> <ul class="devsite-breadcrumb-list" aria-label="Breadcrumb"> <li class="devsite-breadcrumb-item "> <a href="https://www.tensorflow.org/" class="devsite-breadcrumb-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Breadcrumbs" data-value="1" track-type="globalNav" track-name="breadcrumb" track-metadata-position="1" track-metadata-eventdetail="TensorFlow" > TensorFlow </a> </li> <li class="devsite-breadcrumb-item "> <div class="devsite-breadcrumb-guillemet material-icons" 
aria-hidden="true"></div> <a href="https://www.tensorflow.org/learn" class="devsite-breadcrumb-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Breadcrumbs" data-value="2" track-type="globalNav" track-name="breadcrumb" track-metadata-position="2" track-metadata-eventdetail="" > Learn </a> </li> <li class="devsite-breadcrumb-item "> <div class="devsite-breadcrumb-guillemet material-icons" aria-hidden="true"></div> <a href="https://www.tensorflow.org/tutorials" class="devsite-breadcrumb-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Breadcrumbs" data-value="3" track-type="globalNav" track-name="breadcrumb" track-metadata-position="3" track-metadata-eventdetail="TensorFlow Core" > TensorFlow Core </a> </li> <li class="devsite-breadcrumb-item "> <div class="devsite-breadcrumb-guillemet material-icons" aria-hidden="true"></div> <a href="https://www.tensorflow.org/guide" class="devsite-breadcrumb-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Breadcrumbs" data-value="4" track-type="globalNav" track-name="breadcrumb" track-metadata-position="4" track-metadata-eventdetail="" > Guide </a> </li> </ul> <devsite-thumb-rating position="header"> </devsite-thumb-rating> </div> <h1 class="devsite-page-title" tabindex="-1"> The Functional API </h1> <devsite-feature-tooltip ack-key="AckCollectionsBookmarkTooltipDismiss" analytics-category="Site-Wide Custom Events" analytics-action-show="Callout Profile displayed" analytics-action-close="Callout Profile dismissed" analytics-label="Create Collection Callout" class="devsite-page-bookmark-tooltip nocontent" dismiss-button="true" id="devsite-collections-dropdown" dismiss-button-text="Dismiss" close-button-text="Got it"> <devsite-bookmark></devsite-bookmark> <span slot="popout-heading"> Stay organized with collections </span> <span slot="popout-contents"> Save and categorize content based on your preferences. 
</span> </devsite-feature-tooltip> <div class="devsite-page-title-meta"><devsite-view-release-notes></devsite-view-release-notes></div> <devsite-toc class="devsite-nav" depth="2" devsite-toc-embedded > </devsite-toc> <div class="devsite-article-body clearfix "> <p></p> <!-- DO NOT EDIT! Automatically generated file. --> <div itemscope itemtype="http://developers.google.com/ReferenceObject"> <meta itemprop="name" content="The Functional API" /> <meta itemprop="path" content="Guide & Tutorials" /> <meta itemprop="property" content="tf.linalg.matmul"/> <meta itemprop="property" content="tf.stack"/> <meta itemprop="property" content="tf.zeros"/> </div> <p><strong>Author:</strong> <a href="https://twitter.com/fchollet">fchollet</a><br></p> <table class="tfo-notebook-buttons" align="left"> <td> <a target="_blank" href="https://www.tensorflow.org/guide/keras/functional_api"><img src="https://www.tensorflow.org/images/tf_logo_32px.png">View on TensorFlow.org</a> </td> <td> <a target="_blank" href="https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/functional_api.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png">Run in Google Colab</a> </td> <td> <a target="_blank" href="https://github.com/keras-team/keras-io/blob/master/guides/functional_api.py"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png">View source on GitHub</a> </td> <td> <a href="https://keras.io/guides/functional_api"><img src="https://www.tensorflow.org/images/keras32px.png">View on keras.io</a> </td> </table> <h2 id="setup" data-text="Setup" tabindex="-1">Setup</h2> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">import numpy as np import tensorflow as tf from tensorflow import keras from keras import layers </code></pre> <h2 id="introduction" data-text="Introduction" tabindex="-1">Introduction</h2> <p>The Keras <em>functional API</em> is a way to create models that are more flexible than the 
<a href="https://www.tensorflow.org/api_docs/python/tf/keras/Sequential"><code translate="no" dir="ltr">keras.Sequential</code></a> API. The functional API can handle models with non-linear topology, shared layers, and even multiple inputs or outputs.</p> <p>The main idea is that a deep learning model is usually a directed acyclic graph (DAG) of layers. So the functional API is a way to build <em>graphs of layers</em>.</p> <p>Consider the following model:</p> <div class="k-default-codeblock"> <pre translate="no" dir="ltr">(input: 784-dimensional vectors)
       ↧
[Dense (64 units, relu activation)]
       ↧
[Dense (64 units, relu activation)]
       ↧
[Dense (10 units)]
       ↧
(output: logits of a probability distribution over 10 classes)</pre> </div> <p>This is a basic graph with three layers. To build this model using the functional API, start by creating an input node:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">inputs = keras.Input(shape=(784,)) </code></pre> <p>The shape of the data is set as a 784-dimensional vector. The batch size is always omitted since only the shape of each sample is specified.</p> <p>If, for example, you have an image input with a shape of <code translate="no" dir="ltr">(32, 32, 3)</code>, you would use:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr"># Just for demonstration purposes. img_inputs = keras.Input(shape=(32, 32, 3)) </code></pre> <p>The <code translate="no" dir="ltr">inputs</code> object that is returned contains information about the shape and <code translate="no" dir="ltr">dtype</code> of the input data that you feed to your model. 
Here's the shape:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">inputs.shape </code></pre> <pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr"> TensorShape([None, 784]) </pre> <p>Here's the dtype:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">inputs.dtype </code></pre> <pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr"> tf.float32 </pre> <p>You create a new node in the graph of layers by calling a layer on this <code translate="no" dir="ltr">inputs</code> object:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">dense = layers.Dense(64, activation="relu") x = dense(inputs) </code></pre> <p>The "layer call" action is like drawing an arrow from "inputs" to this layer you created. You're "passing" the inputs to the <code translate="no" dir="ltr">dense</code> layer, and you get <code translate="no" dir="ltr">x</code> as the output.</p> <p>Let's add a few more layers to the graph of layers:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">x = layers.Dense(64, activation="relu")(x) outputs = layers.Dense(10)(x) </code></pre> <p>At this point, you can create a <code translate="no" dir="ltr">Model</code> by specifying its inputs and outputs in the graph of layers:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">model = keras.Model(inputs=inputs, outputs=outputs, name="mnist_model") </code></pre> <p>Let's check out what the model summary looks like:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">model.summary() </code></pre> <pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr"> Model: "mnist_model" _________________________________________________________________ Layer (type) Output Shape Param # 
================================================================= input_1 (InputLayer) [(None, 784)] 0 dense (Dense) (None, 64) 50240 dense_1 (Dense) (None, 64) 4160 dense_2 (Dense) (None, 10) 650 ================================================================= Total params: 55050 (215.04 KB) Trainable params: 55050 (215.04 KB) Non-trainable params: 0 (0.00 Byte) _________________________________________________________________ </pre> <p>You can also plot the model as a graph:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">keras.utils.plot_model(model, "my_first_model.png") </code></pre> <p><img src="/static/guide/keras/functional_api_files/output_functional_api_23_0.png" alt="Graph plot of mnist_model: an input layer followed by three Dense layers"></p> <p>And, optionally, display the input and output shapes of each layer in the plotted graph:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">keras.utils.plot_model(model, "my_first_model_with_shape_info.png", show_shapes=True) </code></pre> <p><img src="/static/guide/keras/functional_api_files/output_functional_api_25_0.png" alt="Graph plot of mnist_model with the input and output shapes of each layer"></p> <p>This figure and the code are almost identical. In the code version, the connection arrows are replaced by the call operation.</p> <p>A "graph of layers" is an intuitive mental image for a deep learning model, and the functional API is a way to create models that closely mirrors this.</p> <h2 id="training_evaluation_and_inference" data-text="Training, evaluation, and inference" tabindex="-1">Training, evaluation, and inference</h2> <p>Training, evaluation, and inference work in exactly the same way for models built using the functional API as for <code translate="no" dir="ltr">Sequential</code> models.</p> <p>The <code translate="no" dir="ltr">Model</code> class offers a built-in training loop (the <code translate="no" dir="ltr">fit()</code> method) and a built-in evaluation loop (the <code translate="no" dir="ltr">evaluate()</code> method). 
Note that you can easily <a href="/guide/keras/customizing_what_happens_in_fit">customize these loops</a> to implement training routines beyond supervised learning (e.g. <a href="https://keras.io/examples/generative/dcgan_overriding_train_step/">GANs</a>).</p> <p>Here, load the MNIST image data, reshape it into vectors, fit the model on the data (while monitoring performance on a validation split), then evaluate the model on the test data:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() x_train = x_train.reshape(60000, 784).astype("float32") / 255 x_test = x_test.reshape(10000, 784).astype("float32") / 255 model.compile( loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer=keras.optimizers.RMSprop(), metrics=[keras.metrics.SparseCategoricalAccuracy()], ) history = model.fit(x_train, y_train, batch_size=64, epochs=2, validation_split=0.2) test_scores = model.evaluate(x_test, y_test, verbose=2) print("Test loss:", test_scores[0]) print("Test accuracy:", test_scores[1]) </code></pre> <pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr"> Epoch 1/2 750/750 [==============================] - 4s 3ms/step - loss: 0.3556 - sparse_categorical_accuracy: 0.8971 - val_loss: 0.1962 - val_sparse_categorical_accuracy: 0.9422 Epoch 2/2 750/750 [==============================] - 2s 2ms/step - loss: 0.1612 - sparse_categorical_accuracy: 0.9527 - val_loss: 0.1461 - val_sparse_categorical_accuracy: 0.9592 313/313 - 0s - loss: 0.1492 - sparse_categorical_accuracy: 0.9556 - 463ms/epoch - 1ms/step Test loss: 0.14915992319583893 Test accuracy: 0.9556000232696533 </pre> <p>For further reading, see the <a href="/guide/keras/training_with_built_in_methods">training and evaluation</a> guide.</p> <h2 id="save_and_serialize" data-text="Save and serialize" tabindex="-1">Save and serialize</h2> <p>Saving the model and serialization work the 
same way for models built using the functional API as they do for <code translate="no" dir="ltr">Sequential</code> models. The standard way to save a functional model is to call <code translate="no" dir="ltr">model.save()</code> to save the entire model as a single file. You can later recreate the same model from this file, even if the code that built the model is no longer available.</p> <p>This saved file includes:</p> <ul> <li>model architecture</li> <li>model weight values (that were learned during training)</li> <li>model training config, if any (as passed to <code translate="no" dir="ltr">compile</code>)</li> <li>optimizer and its state, if any (to restart training where you left off)</li> </ul> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">model.save("path_to_my_model.keras") del model # Recreate the exact same model purely from the file: model = keras.models.load_model("path_to_my_model.keras") </code></pre> <p>For details, read the model <a href="/guide/keras/serialization_and_saving">serialization &amp; saving</a> guide.</p> <h2 id="use_the_same_graph_of_layers_to_define_multiple_models" data-text="Use the same graph of layers to define multiple models" tabindex="-1">Use the same graph of layers to define multiple models</h2> <p>In the functional API, models are created by specifying their inputs and outputs in a graph of layers. 
That means that a single graph of layers can be used to generate multiple models.</p> <p>In the example below, you use the same stack of layers to instantiate two models: an <code translate="no" dir="ltr">encoder</code> model that turns image inputs into 16-dimensional vectors, and an end-to-end <code translate="no" dir="ltr">autoencoder</code> model for training.</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">encoder_input = keras.Input(shape=(28, 28, 1), name="img") x = layers.Conv2D(16, 3, activation="relu")(encoder_input) x = layers.Conv2D(32, 3, activation="relu")(x) x = layers.MaxPooling2D(3)(x) x = layers.Conv2D(32, 3, activation="relu")(x) x = layers.Conv2D(16, 3, activation="relu")(x) encoder_output = layers.GlobalMaxPooling2D()(x) encoder = keras.Model(encoder_input, encoder_output, name="encoder") encoder.summary() x = layers.Reshape((4, 4, 1))(encoder_output) x = layers.Conv2DTranspose(16, 3, activation="relu")(x) x = layers.Conv2DTranspose(32, 3, activation="relu")(x) x = layers.UpSampling2D(3)(x) x = layers.Conv2DTranspose(16, 3, activation="relu")(x) decoder_output = layers.Conv2DTranspose(1, 3, activation="relu")(x) autoencoder = keras.Model(encoder_input, decoder_output, name="autoencoder") autoencoder.summary() </code></pre> <pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr"> Model: "encoder" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= img (InputLayer) [(None, 28, 28, 1)] 0 conv2d (Conv2D) (None, 26, 26, 16) 160 conv2d_1 (Conv2D) (None, 24, 24, 32) 4640 max_pooling2d (MaxPooling2 (None, 8, 8, 32) 0 D) conv2d_2 (Conv2D) (None, 6, 6, 32) 9248 conv2d_3 (Conv2D) (None, 4, 4, 16) 4624 global_max_pooling2d (Glob (None, 16) 0 alMaxPooling2D) ================================================================= Total params: 18672 (72.94 KB) Trainable params: 18672 (72.94 
KB) Non-trainable params: 0 (0.00 Byte) _________________________________________________________________ Model: "autoencoder" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= img (InputLayer) [(None, 28, 28, 1)] 0 conv2d (Conv2D) (None, 26, 26, 16) 160 conv2d_1 (Conv2D) (None, 24, 24, 32) 4640 max_pooling2d (MaxPooling2 (None, 8, 8, 32) 0 D) conv2d_2 (Conv2D) (None, 6, 6, 32) 9248 conv2d_3 (Conv2D) (None, 4, 4, 16) 4624 global_max_pooling2d (Glob (None, 16) 0 alMaxPooling2D) reshape (Reshape) (None, 4, 4, 1) 0 conv2d_transpose (Conv2DTr (None, 6, 6, 16) 160 anspose) conv2d_transpose_1 (Conv2D (None, 8, 8, 32) 4640 Transpose) up_sampling2d (UpSampling2 (None, 24, 24, 32) 0 D) conv2d_transpose_2 (Conv2D (None, 26, 26, 16) 4624 Transpose) conv2d_transpose_3 (Conv2D (None, 28, 28, 1) 145 Transpose) ================================================================= Total params: 28241 (110.32 KB) Trainable params: 28241 (110.32 KB) Non-trainable params: 0 (0.00 Byte) _________________________________________________________________ </pre> <p>Here, the decoding architecture is strictly symmetrical to the encoding architecture, so the output shape is the same as the input shape <code translate="no" dir="ltr">(28, 28, 1)</code>.</p> <p>The reverse of a <code translate="no" dir="ltr">Conv2D</code> layer is a <code translate="no" dir="ltr">Conv2DTranspose</code> layer, and the reverse of a <code translate="no" dir="ltr">MaxPooling2D</code> layer is an <code translate="no" dir="ltr">UpSampling2D</code> layer.</p> <h2 id="all_models_are_callable_just_like_layers" data-text="All models are callable, just like layers" tabindex="-1">All models are callable, just like layers</h2> <p>You can treat any model as if it were a layer by invoking it on an <code translate="no" dir="ltr">Input</code> or on the output of another layer. 
By calling a model you aren't just reusing the architecture of the model, you're also reusing its weights.</p> <p>To see this in action, here's a different take on the autoencoder example that creates an encoder model, a decoder model, and chains them in two calls to obtain the autoencoder model:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">encoder_input = keras.Input(shape=(28, 28, 1), name="original_img") x = layers.Conv2D(16, 3, activation="relu")(encoder_input) x = layers.Conv2D(32, 3, activation="relu")(x) x = layers.MaxPooling2D(3)(x) x = layers.Conv2D(32, 3, activation="relu")(x) x = layers.Conv2D(16, 3, activation="relu")(x) encoder_output = layers.GlobalMaxPooling2D()(x) encoder = keras.Model(encoder_input, encoder_output, name="encoder") encoder.summary() decoder_input = keras.Input(shape=(16,), name="encoded_img") x = layers.Reshape((4, 4, 1))(decoder_input) x = layers.Conv2DTranspose(16, 3, activation="relu")(x) x = layers.Conv2DTranspose(32, 3, activation="relu")(x) x = layers.UpSampling2D(3)(x) x = layers.Conv2DTranspose(16, 3, activation="relu")(x) decoder_output = layers.Conv2DTranspose(1, 3, activation="relu")(x) decoder = keras.Model(decoder_input, decoder_output, name="decoder") decoder.summary() autoencoder_input = keras.Input(shape=(28, 28, 1), name="img") encoded_img = encoder(autoencoder_input) decoded_img = decoder(encoded_img) autoencoder = keras.Model(autoencoder_input, decoded_img, name="autoencoder") autoencoder.summary() </code></pre> <pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr"> Model: "encoder" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= original_img (InputLayer) [(None, 28, 28, 1)] 0 conv2d_4 (Conv2D) (None, 26, 26, 16) 160 conv2d_5 (Conv2D) (None, 24, 24, 32) 4640 max_pooling2d_1 (MaxPoolin (None, 8, 8, 32) 0 g2D) conv2d_6 (Conv2D) (None, 
6, 6, 32) 9248 conv2d_7 (Conv2D) (None, 4, 4, 16) 4624 global_max_pooling2d_1 (Gl (None, 16) 0 obalMaxPooling2D) ================================================================= Total params: 18672 (72.94 KB) Trainable params: 18672 (72.94 KB) Non-trainable params: 0 (0.00 Byte) _________________________________________________________________ Model: "decoder" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= encoded_img (InputLayer) [(None, 16)] 0 reshape_1 (Reshape) (None, 4, 4, 1) 0 conv2d_transpose_4 (Conv2D (None, 6, 6, 16) 160 Transpose) conv2d_transpose_5 (Conv2D (None, 8, 8, 32) 4640 Transpose) up_sampling2d_1 (UpSamplin (None, 24, 24, 32) 0 g2D) conv2d_transpose_6 (Conv2D (None, 26, 26, 16) 4624 Transpose) conv2d_transpose_7 (Conv2D (None, 28, 28, 1) 145 Transpose) ================================================================= Total params: 9569 (37.38 KB) Trainable params: 9569 (37.38 KB) Non-trainable params: 0 (0.00 Byte) _________________________________________________________________ Model: "autoencoder" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= img (InputLayer) [(None, 28, 28, 1)] 0 encoder (Functional) (None, 16) 18672 decoder (Functional) (None, 28, 28, 1) 9569 ================================================================= Total params: 28241 (110.32 KB) Trainable params: 28241 (110.32 KB) Non-trainable params: 0 (0.00 Byte) _________________________________________________________________ </pre> <p>As you can see, the model can be nested: a model can contain sub-models (since a model is just like a layer). A common use case for model nesting is <em>ensembling</em>. 
For example, here's how to ensemble a set of models into a single model that averages their predictions:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">def get_model(): inputs = keras.Input(shape=(128,)) outputs = layers.Dense(1)(inputs) return keras.Model(inputs, outputs) model1 = get_model() model2 = get_model() model3 = get_model() inputs = keras.Input(shape=(128,)) y1 = model1(inputs) y2 = model2(inputs) y3 = model3(inputs) outputs = layers.average([y1, y2, y3]) ensemble_model = keras.Model(inputs=inputs, outputs=outputs) </code></pre> <h2 id="manipulate_complex_graph_topologies" data-text="Manipulate complex graph topologies" tabindex="-1">Manipulate complex graph topologies</h2> <h3 id="models_with_multiple_inputs_and_outputs" data-text="Models with multiple inputs and outputs" tabindex="-1">Models with multiple inputs and outputs</h3> <p>The functional API makes it easy to manipulate multiple inputs and outputs. This cannot be handled with the <code translate="no" dir="ltr">Sequential</code> API.</p> <p>For example, if you're building a system for ranking customer issue tickets by priority and routing them to the correct department, then the model will have three inputs:</p> <ul> <li>the title of the ticket (text input),</li> <li>the text body of the ticket (text input), and</li> <li>any tags added by the user (categorical input)</li> </ul> <p>This model will have two outputs:</p> <ul> <li>the priority score between 0 and 1 (scalar sigmoid output), and</li> <li>the department that should handle the ticket (softmax output over the set of departments).</li> </ul> <p>You can build this model in a few lines with the functional API:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">num_tags = 12 # Number of unique issue tags num_words = 10000 # Size of vocabulary obtained when preprocessing text data num_departments = 4 # Number of departments for predictions 
title_input = keras.Input( shape=(None,), name="title" ) # Variable-length sequence of ints body_input = keras.Input(shape=(None,), name="body") # Variable-length sequence of ints tags_input = keras.Input( shape=(num_tags,), name="tags" ) # Binary vectors of size `num_tags` # Embed each word in the title into a 64-dimensional vector title_features = layers.Embedding(num_words, 64)(title_input) # Embed each word in the text into a 64-dimensional vector body_features = layers.Embedding(num_words, 64)(body_input) # Reduce sequence of embedded words in the title into a single 128-dimensional vector title_features = layers.LSTM(128)(title_features) # Reduce sequence of embedded words in the body into a single 32-dimensional vector body_features = layers.LSTM(32)(body_features) # Merge all available features into a single large vector via concatenation x = layers.concatenate([title_features, body_features, tags_input]) # Stick a logistic regression for priority prediction on top of the features priority_pred = layers.Dense(1, name="priority")(x) # Stick a department classifier on top of the features department_pred = layers.Dense(num_departments, name="department")(x) # Instantiate an end-to-end model predicting both priority and department model = keras.Model( inputs=[title_input, body_input, tags_input], outputs=[priority_pred, department_pred], ) </code></pre> <p>Now plot the model:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">keras.utils.plot_model(model, "multi_input_and_output_model.png", show_shapes=True) </code></pre> <p><img src="/static/guide/keras/functional_api_files/output_functional_api_43_0.png" alt="png"></p> <p>When compiling this model, you can assign different losses to each output. 
You can even assign different weights to each loss -- to modulate their contribution to the total training loss.</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">model.compile( optimizer=keras.optimizers.RMSprop(1e-3), loss=[ keras.losses.BinaryCrossentropy(from_logits=True), keras.losses.CategoricalCrossentropy(from_logits=True), ], loss_weights=[1.0, 0.2], ) </code></pre> <p>Since the output layers have different names, you could also specify the losses and loss weights with the corresponding layer names:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">model.compile( optimizer=keras.optimizers.RMSprop(1e-3), loss={ "priority": keras.losses.BinaryCrossentropy(from_logits=True), "department": keras.losses.CategoricalCrossentropy(from_logits=True), }, loss_weights={"priority": 1.0, "department": 0.2}, ) </code></pre> <p>Train the model by passing lists of NumPy arrays of inputs and targets:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr"># Dummy input data title_data = np.random.randint(num_words, size=(1280, 10)) body_data = np.random.randint(num_words, size=(1280, 100)) tags_data = np.random.randint(2, size=(1280, num_tags)).astype("float32") # Dummy target data priority_targets = np.random.random(size=(1280, 1)) dept_targets = np.random.randint(2, size=(1280, num_departments)) model.fit( {"title": title_data, "body": body_data, "tags": tags_data}, {"priority": priority_targets, "department": dept_targets}, epochs=2, batch_size=32, ) </code></pre> <pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr"> Epoch 1/2 40/40 [==============================] - 8s 112ms/step - loss: 1.2982 - priority_loss: 0.6991 - department_loss: 2.9958 Epoch 2/2 40/40 [==============================] - 3s 64ms/step - loss: 1.3110 - priority_loss: 0.6977 - department_loss: 3.0666 <keras.src.callbacks.History at 0x7f08d51fab80> 
</pre> <p>When calling fit with a <code translate="no" dir="ltr">Dataset</code> object, it should yield either a tuple of lists like <code translate="no" dir="ltr">([title_data, body_data, tags_data], [priority_targets, dept_targets])</code> or a tuple of dictionaries like <code translate="no" dir="ltr">({'title': title_data, 'body': body_data, 'tags': tags_data}, {'priority': priority_targets, 'department': dept_targets})</code>.</p> <p>For more detailed explanation, refer to the <a href="/guide/keras/training_with_built_in_methods">training and evaluation</a> guide.</p> <h3 id="a_toy_resnet_model" data-text="A toy ResNet model" tabindex="-1">A toy ResNet model</h3> <p>In addition to models with multiple inputs and outputs, the functional API makes it easy to manipulate non-linear connectivity topologies -- these are models with layers that are not connected sequentially, which the <code translate="no" dir="ltr">Sequential</code> API cannot handle.</p> <p>A common use case for this is residual connections. 
Let's build a toy ResNet model for CIFAR10 to demonstrate this:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">inputs = keras.Input(shape=(32, 32, 3), name="img") x = layers.Conv2D(32, 3, activation="relu")(inputs) x = layers.Conv2D(64, 3, activation="relu")(x) block_1_output = layers.MaxPooling2D(3)(x) x = layers.Conv2D(64, 3, activation="relu", padding="same")(block_1_output) x = layers.Conv2D(64, 3, activation="relu", padding="same")(x) block_2_output = layers.add([x, block_1_output]) x = layers.Conv2D(64, 3, activation="relu", padding="same")(block_2_output) x = layers.Conv2D(64, 3, activation="relu", padding="same")(x) block_3_output = layers.add([x, block_2_output]) x = layers.Conv2D(64, 3, activation="relu")(block_3_output) x = layers.GlobalAveragePooling2D()(x) x = layers.Dense(256, activation="relu")(x) x = layers.Dropout(0.5)(x) outputs = layers.Dense(10)(x) model = keras.Model(inputs, outputs, name="toy_resnet") model.summary() </code></pre> <pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr"> Model: "toy_resnet" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== img (InputLayer) [(None, 32, 32, 3)] 0 [] conv2d_8 (Conv2D) (None, 30, 30, 32) 896 ['img[0][0]'] conv2d_9 (Conv2D) (None, 28, 28, 64) 18496 ['conv2d_8[0][0]'] max_pooling2d_2 (MaxPoolin (None, 9, 9, 64) 0 ['conv2d_9[0][0]'] g2D) conv2d_10 (Conv2D) (None, 9, 9, 64) 36928 ['max_pooling2d_2[0][0]'] conv2d_11 (Conv2D) (None, 9, 9, 64) 36928 ['conv2d_10[0][0]'] add (Add) (None, 9, 9, 64) 0 ['conv2d_11[0][0]', 'max_pooling2d_2[0][0]'] conv2d_12 (Conv2D) (None, 9, 9, 64) 36928 ['add[0][0]'] conv2d_13 (Conv2D) (None, 9, 9, 64) 36928 ['conv2d_12[0][0]'] add_1 (Add) (None, 9, 9, 64) 0 ['conv2d_13[0][0]', 'add[0][0]'] conv2d_14 (Conv2D) (None, 7, 7, 
64) 36928 ['add_1[0][0]'] global_average_pooling2d ( (None, 64) 0 ['conv2d_14[0][0]'] GlobalAveragePooling2D) dense_6 (Dense) (None, 256) 16640 ['global_average_pooling2d[0][ 0]'] dropout (Dropout) (None, 256) 0 ['dense_6[0][0]'] dense_7 (Dense) (None, 10) 2570 ['dropout[0][0]'] ================================================================================================== Total params: 223242 (872.04 KB) Trainable params: 223242 (872.04 KB) Non-trainable params: 0 (0.00 Byte) __________________________________________________________________________________________________ </pre> <p>Plot the model:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">keras.utils.plot_model(model, "mini_resnet.png", show_shapes=True) </code></pre> <p><img src="/static/guide/keras/functional_api_files/output_functional_api_54_0.png" alt="png"></p> <p>Now train the model:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data() x_train = x_train.astype("float32") / 255.0 x_test = x_test.astype("float32") / 255.0 y_train = keras.utils.to_categorical(y_train, 10) y_test = keras.utils.to_categorical(y_test, 10) model.compile( optimizer=keras.optimizers.RMSprop(1e-3), loss=keras.losses.CategoricalCrossentropy(from_logits=True), metrics=["acc"], ) # We restrict the data to the first 1000 samples so as to limit execution time # on Colab. Try to train on the entire dataset until convergence! 
model.fit(x_train[:1000], y_train[:1000], batch_size=64, epochs=1, validation_split=0.2) </code></pre> <pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr"> 13/13 [==============================] - 4s 39ms/step - loss: 2.3086 - acc: 0.0988 - val_loss: 2.3020 - val_acc: 0.0850 <keras.src.callbacks.History at 0x7f078810c880> </pre> <h2 id="shared_layers" data-text="Shared layers" tabindex="-1">Shared layers</h2> <p>Another good use for the functional API are models that use <em>shared layers</em>. Shared layers are layer instances that are reused multiple times in the same model -- they learn features that correspond to multiple paths in the graph-of-layers.</p> <p>Shared layers are often used to encode inputs from similar spaces (say, two different pieces of text that feature similar vocabulary). They enable sharing of information across these different inputs, and they make it possible to train such a model on less data. If a given word is seen in one of the inputs, that will benefit the processing of all inputs that pass through the shared layer.</p> <p>To share a layer in the functional API, call the same layer instance multiple times. 
For instance, here's an <code translate="no" dir="ltr">Embedding</code> layer shared across two different text inputs:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr"># Embedding for 1000 unique words mapped to 128-dimensional vectors shared_embedding = layers.Embedding(1000, 128) # Variable-length sequence of integers text_input_a = keras.Input(shape=(None,), dtype="int32") # Variable-length sequence of integers text_input_b = keras.Input(shape=(None,), dtype="int32") # Reuse the same layer to encode both inputs encoded_input_a = shared_embedding(text_input_a) encoded_input_b = shared_embedding(text_input_b) </code></pre> <h2 id="extract_and_reuse_nodes_in_the_graph_of_layers" data-text="Extract and reuse nodes in the graph of layers" tabindex="-1">Extract and reuse nodes in the graph of layers</h2> <p>Because the graph of layers you are manipulating is a static data structure, it can be accessed and inspected. And this is how you are able to plot functional models as images.</p> <p>This also means that you can access the activations of intermediate layers ("nodes" in the graph) and reuse them elsewhere -- which is very useful for something like feature extraction.</p> <p>Let's look at an example. 
This is a VGG19 model with weights pretrained on ImageNet:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">vgg19 = keras.applications.VGG19() </code></pre> <pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr"> Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg19/vgg19_weights_tf_dim_ordering_tf_kernels.h5 574710816/574710816 [==============================] - 4s 0us/step </pre> <p>And these are the intermediate activations of the model, obtained by querying the graph data structure:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">features_list = [layer.output for layer in vgg19.layers] </code></pre> <p>Use these features to create a new feature-extraction model that returns the values of the intermediate layer activations:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">feat_extraction_model = keras.Model(inputs=vgg19.input, outputs=features_list) img = np.random.random((1, 224, 224, 3)).astype("float32") extracted_features = feat_extraction_model(img) </code></pre> <p>This comes in handy for tasks like <a href="https://keras.io/examples/generative/neural_style_transfer/">neural style transfer</a>, among other things.</p> <h2 id="extend_the_api_using_custom_layers" data-text="Extend the API using custom layers" tabindex="-1">Extend the API using custom layers</h2> <p><code translate="no" dir="ltr">keras</code> includes a wide range of built-in layers, for example:</p> <ul> <li>Convolutional layers: <code translate="no" dir="ltr">Conv1D</code>, <code translate="no" dir="ltr">Conv2D</code>, <code translate="no" dir="ltr">Conv3D</code>, <code translate="no" dir="ltr">Conv2DTranspose</code></li> <li>Pooling layers: <code translate="no" dir="ltr">MaxPooling1D</code>, <code translate="no" dir="ltr">MaxPooling2D</code>, <code translate="no" dir="ltr">MaxPooling3D</code>, <code 
translate="no" dir="ltr">AveragePooling1D</code></li> <li>RNN layers: <code translate="no" dir="ltr">GRU</code>, <code translate="no" dir="ltr">LSTM</code>, <code translate="no" dir="ltr">ConvLSTM2D</code></li> <li><code translate="no" dir="ltr">BatchNormalization</code>, <code translate="no" dir="ltr">Dropout</code>, <code translate="no" dir="ltr">Embedding</code>, etc.</li> </ul> <p>But if you don't find what you need, it's easy to extend the API by creating your own layers. All layers subclass the <code translate="no" dir="ltr">Layer</code> class and implement:</p> <ul> <li><code translate="no" dir="ltr">call</code> method, that specifies the computation done by the layer.</li> <li><code translate="no" dir="ltr">build</code> method, that creates the weights of the layer (this is just a style convention since you can create weights in <code translate="no" dir="ltr">__init__</code>, as well).</li> </ul> <p>To learn more about creating layers from scratch, read <a href="/guide/keras/making_new_layers_and_models_via_subclassing">custom layers and models</a> guide.</p> <p>The following is a basic implementation of <a href="https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense"><code translate="no" dir="ltr">keras.layers.Dense</code></a>:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">class CustomDense(layers.Layer): def __init__(self, units=32): super().__init__() self.units = units def build(self, input_shape): self.w = self.add_weight( shape=(input_shape[-1], self.units), initializer="random_normal", trainable=True, ) self.b = self.add_weight( shape=(self.units,), initializer="random_normal", trainable=True ) def call(self, inputs): return tf.matmul(inputs, self.w) + self.b inputs = keras.Input((4,)) outputs = CustomDense(10)(inputs) model = keras.Model(inputs, outputs) </code></pre> <p>For serialization support in your custom layer, define a <code translate="no" dir="ltr">get_config</code> method that 
returns the constructor arguments of the layer instance:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">@keras.saving.register_keras_serializable() class CustomDense(layers.Layer): def __init__(self, units=32): super().__init__() self.units = units def build(self, input_shape): self.w = self.add_weight( shape=(input_shape[-1], self.units), initializer="random_normal", trainable=True, ) self.b = self.add_weight( shape=(self.units,), initializer="random_normal", trainable=True ) def call(self, inputs): return tf.matmul(inputs, self.w) + self.b def get_config(self): return {"units": self.units} inputs = keras.Input((4,)) outputs = CustomDense(10)(inputs) model = keras.Model(inputs, outputs) config = model.get_config() new_model = keras.Model.from_config(config) </code></pre> <p>Optionally, implement the class method <code translate="no" dir="ltr">from_config(cls, config)</code> which is used when recreating a layer instance given its config dictionary. The default implementation of <code translate="no" dir="ltr">from_config</code> is:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">def from_config(cls, config): return cls(**config) </code></pre> <h2 id="when_to_use_the_functional_api" data-text="When to use the functional API" tabindex="-1">When to use the functional API</h2> <p>Should you use the Keras functional API to create a new model, or just subclass the <code translate="no" dir="ltr">Model</code> class directly? In general, the functional API is higher-level, easier and safer, and has a number of features that subclassed models do not support.</p> <p>However, model subclassing provides greater flexibility when building models that are not easily expressible as directed acyclic graphs of layers. 
For example, you could not implement a Tree-RNN with the functional API and would have to subclass <code translate="no" dir="ltr">Model</code> directly.</p> <p>For an in-depth look at the differences between the functional API and model subclassing, read <a href="https://blog.tensorflow.org/2019/01/what-are-symbolic-and-imperative-apis.html">What are Symbolic and Imperative APIs in TensorFlow 2.0?</a>.</p> <h3 id="functional_api_strengths" data-text="Functional API strengths:" tabindex="-1">Functional API strengths:</h3> <p>The following properties are also true for Sequential models (which are also data structures), but are not true for subclassed models (which are Python bytecode, not data structures).</p> <h4 id="less_verbose" data-text="Less verbose" tabindex="-1">Less verbose</h4> <p>There is no <code translate="no" dir="ltr">super().__init__(...)</code>, no <code translate="no" dir="ltr">def call(self, ...):</code>, etc.</p> <p>Compare:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">inputs = keras.Input(shape=(32,)) x = layers.Dense(64, activation='relu')(inputs) outputs = layers.Dense(10)(x) mlp = keras.Model(inputs, outputs) </code></pre> <p>With the subclassed version:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">class MLP(keras.Model): def __init__(self, **kwargs): super().__init__(**kwargs) self.dense_1 = layers.Dense(64, activation='relu') self.dense_2 = layers.Dense(10) def call(self, inputs): x = self.dense_1(inputs) return self.dense_2(x) # Instantiate the model. mlp = MLP() # Necessary to create the model's state. # The model doesn't have a state until it's called at least once. 
_ = mlp(tf.zeros((1, 32))) </code></pre> <h4 id="model_validation_while_defining_its_connectivity_graph" data-text="Model validation while defining its connectivity graph" tabindex="-1">Model validation while defining its connectivity graph</h4> <p>In the functional API, the input specification (shape and dtype) is created in advance (using <code translate="no" dir="ltr">Input</code>). Every time you call a layer, the layer checks that the specification passed to it matches its assumptions, and it will raise a helpful error message if not.</p> <p>This guarantees that any model you can build with the functional API will run. All debugging -- other than convergence-related debugging -- happens statically during the model construction and not at execution time. This is similar to type checking in a compiler.</p> <h4 id="a_functional_model_is_plottable_and_inspectable" data-text="A functional model is plottable and inspectable" tabindex="-1">A functional model is plottable and inspectable</h4> <p>You can plot the model as a graph, and you can easily access intermediate nodes in this graph. For example, to extract and reuse the activations of intermediate layers (as seen in a previous example):</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">features_list = [layer.output for layer in vgg19.layers] feat_extraction_model = keras.Model(inputs=vgg19.input, outputs=features_list) </code></pre> <h4 id="a_functional_model_can_be_serialized_or_cloned" data-text="A functional model can be serialized or cloned" tabindex="-1">A functional model can be serialized or cloned</h4> <p>Because a functional model is a data structure rather than a piece of code, it is safely serializable and can be saved as a single file that allows you to recreate the exact same model without having access to any of the original code. 
See the <a href="/guide/keras/serialization_and_saving">serialization & saving guide</a>.</p> <p>To serialize a subclassed model, it is necessary for the implementer to specify a <code translate="no" dir="ltr">get_config()</code> and <code translate="no" dir="ltr">from_config()</code> method at the model level.</p> <h3 id="functional_api_weakness" data-text="Functional API weakness:" tabindex="-1">Functional API weakness:</h3> <h4 id="it_does_not_support_dynamic_architectures" data-text="It does not support dynamic architectures" tabindex="-1">It does not support dynamic architectures</h4> <p>The functional API treats models as DAGs of layers. This is true for most deep learning architectures, but not all -- for example, recursive networks or Tree RNNs do not follow this assumption and cannot be implemented in the functional API.</p> <h2 id="mix-and-match_api_styles" data-text="Mix-and-match API styles" tabindex="-1">Mix-and-match API styles</h2> <p>Choosing between the functional API or Model subclassing isn't a binary decision that restricts you into one category of models. 
All models in the <code translate="no" dir="ltr">keras</code> API can interact with each other, whether they're <code translate="no" dir="ltr">Sequential</code> models, functional models, or subclassed models that are written from scratch.</p> <p>You can always use a functional model or <code translate="no" dir="ltr">Sequential</code> model as part of a subclassed model or layer:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">units = 32 timesteps = 10 input_dim = 5 # Define a Functional model inputs = keras.Input((None, units)) x = layers.GlobalAveragePooling1D()(inputs) outputs = layers.Dense(1)(x) model = keras.Model(inputs, outputs) @keras.saving.register_keras_serializable() class CustomRNN(layers.Layer): def __init__(self): super().__init__() self.units = units self.projection_1 = layers.Dense(units=units, activation="tanh") self.projection_2 = layers.Dense(units=units, activation="tanh") # Our previously-defined Functional model self.classifier = model def call(self, inputs): outputs = [] state = tf.zeros(shape=(inputs.shape[0], self.units)) for t in range(inputs.shape[1]): x = inputs[:, t, :] h = self.projection_1(x) y = h + self.projection_2(state) state = y outputs.append(y) features = tf.stack(outputs, axis=1) print(features.shape) return self.classifier(features) rnn_model = CustomRNN() _ = rnn_model(tf.zeros((1, timesteps, input_dim))) </code></pre> <pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr"> (1, 10, 32) </pre> <p>You can use any subclassed layer or model in the functional API as long as it implements a <code translate="no" dir="ltr">call</code> method that follows one of the following patterns:</p> <ul> <li><code translate="no" dir="ltr">call(self, inputs, **kwargs)</code> -- Where <code translate="no" dir="ltr">inputs</code> is a tensor or a nested structure of tensors (e.g. 
a list of tensors), and where <code translate="no" dir="ltr">**kwargs</code> are non-tensor arguments (non-inputs).</li> <li><code translate="no" dir="ltr">call(self, inputs, training=None, **kwargs)</code> -- Where <code translate="no" dir="ltr">training</code> is a boolean indicating whether the layer should behave in training mode and inference mode.</li> <li><code translate="no" dir="ltr">call(self, inputs, mask=None, **kwargs)</code> -- Where <code translate="no" dir="ltr">mask</code> is a boolean mask tensor (useful for RNNs, for instance).</li> <li><code translate="no" dir="ltr">call(self, inputs, training=None, mask=None, **kwargs)</code> -- Of course, you can have both masking and training-specific behavior at the same time.</li> </ul> <p>Additionally, if you implement the <code translate="no" dir="ltr">get_config</code> method on your custom Layer or model, the functional models you create will still be serializable and cloneable.</p> <p>Here's a quick example of a custom RNN, written from scratch, being used in a functional model:</p> <pre class="prettyprint lang-python" translate="no" dir="ltr"><code translate="no" dir="ltr">units = 32 timesteps = 10 input_dim = 5 batch_size = 16 @keras.saving.register_keras_serializable() class CustomRNN(layers.Layer): def __init__(self): super().__init__() self.units = units self.projection_1 = layers.Dense(units=units, activation="tanh") self.projection_2 = layers.Dense(units=units, activation="tanh") self.classifier = layers.Dense(1) def call(self, inputs): outputs = [] state = tf.zeros(shape=(inputs.shape[0], self.units)) for t in range(inputs.shape[1]): x = inputs[:, t, :] h = self.projection_1(x) y = h + self.projection_2(state) state = y outputs.append(y) features = tf.stack(outputs, axis=1) return self.classifier(features) # Note that you specify a static batch size for the inputs with the `batch_shape` # arg, because the inner computation of `CustomRNN` requires a static batch size # (when you create the 
`state` zeros tensor). inputs = keras.Input(batch_shape=(batch_size, timesteps, input_dim)) x = layers.Conv1D(32, 3)(inputs) outputs = CustomRNN()(x) model = keras.Model(inputs, outputs) rnn_model = CustomRNN() _ = rnn_model(tf.zeros((1, 10, 5))) </code></pre> </div> <devsite-thumb-rating position="footer"> </devsite-thumb-rating> <div class="devsite-floating-action-buttons"> </div> </article> <devsite-content-footer class="nocontent"> <p>Except as otherwise noted, the content of this page is licensed under the <a href="https://creativecommons.org/licenses/by/4.0/">Creative Commons Attribution 4.0 License</a>, and code samples are licensed under the <a href="https://www.apache.org/licenses/LICENSE-2.0">Apache 2.0 License</a>. For details, see the <a href="https://developers.google.com/site-policies">Google Developers Site Policies</a>. Java is a registered trademark of Oracle and/or its affiliates.</p> <p>Last updated 2024-04-12 UTC.</p> </devsite-content-footer> <devsite-notification > </devsite-notification> <div class="devsite-content-data"> <template class="devsite-content-data-template"> [[["Easy to understand","easyToUnderstand","thumb-up"],["Solved my problem","solvedMyProblem","thumb-up"],["Other","otherUp","thumb-up"]],[["Missing the information I need","missingTheInformationINeed","thumb-down"],["Too complicated / too many steps","tooComplicatedTooManySteps","thumb-down"],["Out of date","outOfDate","thumb-down"],["Samples / code issue","samplesCodeIssue","thumb-down"],["Other","otherDown","thumb-down"]],["Last updated 2024-04-12 UTC."],[],[]] </template> </div> </devsite-content> </main> <devsite-footer-promos class="devsite-footer"> </devsite-footer-promos> <devsite-footer-linkboxes class="devsite-footer"> <nav class="devsite-footer-linkboxes nocontent" aria-label="Footer links"> <ul class="devsite-footer-linkboxes-list"> <li class="devsite-footer-linkbox "> <h3 class="devsite-footer-linkbox-heading no-link">Stay connected</h3> <ul 
class="devsite-footer-linkbox-list"> <li class="devsite-footer-linkbox-item"> <a href="//blog.tensorflow.org" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 1)" > Blog </a> </li> <li class="devsite-footer-linkbox-item"> <a href="//discuss.tensorflow.org" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 2)" > Forum </a> </li> <li class="devsite-footer-linkbox-item"> <a href="//github.com/tensorflow/" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 3)" > GitHub </a> </li> <li class="devsite-footer-linkbox-item"> <a href="//twitter.com/tensorflow" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 4)" > Twitter </a> </li> <li class="devsite-footer-linkbox-item"> <a href="//youtube.com/tensorflow" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 5)" > YouTube </a> </li> </ul> </li> <li class="devsite-footer-linkbox "> <h3 class="devsite-footer-linkbox-heading no-link">Support</h3> <ul class="devsite-footer-linkbox-list"> <li class="devsite-footer-linkbox-item"> <a href="//github.com/tensorflow/tensorflow/issues" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 1)" > Issue tracker </a> </li> <li class="devsite-footer-linkbox-item"> <a href="//github.com/tensorflow/tensorflow/blob/master/RELEASE.md" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 2)" > Release notes </a> </li> <li class="devsite-footer-linkbox-item"> <a href="//stackoverflow.com/questions/tagged/tensorflow" class="devsite-footer-linkbox-link gc-analytics-event" 
data-category="Site-Wide Custom Events" data-label="Footer Link (index 3)" > Stack Overflow </a> </li> <li class="devsite-footer-linkbox-item"> <a href="/extras/tensorflow_brand_guidelines.pdf" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 4)" > Brand guidelines </a> </li> <li class="devsite-footer-linkbox-item"> <a href="/about/bib" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 5)" > Cite TensorFlow </a> </li> </ul> </li> </ul> </nav> </devsite-footer-linkboxes> <devsite-footer-utility class="devsite-footer"> <div class="devsite-footer-utility nocontent"> <nav class="devsite-footer-utility-links" aria-label="Utility links"> <ul class="devsite-footer-utility-list"> <li class="devsite-footer-utility-item "> <a class="devsite-footer-utility-link gc-analytics-event" href="//policies.google.com/terms" data-category="Site-Wide Custom Events" data-label="Footer Terms link" > Terms </a> </li> <li class="devsite-footer-utility-item "> <a class="devsite-footer-utility-link gc-analytics-event" href="//policies.google.com/privacy" data-category="Site-Wide Custom Events" data-label="Footer Privacy link" > Privacy </a> </li> <li class="devsite-footer-utility-item glue-cookie-notification-bar-control"> <a class="devsite-footer-utility-link gc-analytics-event" href="#" data-category="Site-Wide Custom Events" data-label="Footer Manage cookies link" aria-hidden="true" > Manage cookies </a> </li> <li class="devsite-footer-utility-item devsite-footer-utility-button"> <span class="devsite-footer-utility-description">Sign up for the TensorFlow newsletter</span> <a class="devsite-footer-utility-link gc-analytics-event" href="//www.tensorflow.org/subscribe" data-category="Site-Wide Custom Events" data-label="Footer Subscribe link" > Subscribe </a> </li> </ul> <devsite-language-selector> <ul role="presentation"> <li 
role="presentation"> <a role="menuitem" lang="en" >English</a> </li> <li role="presentation"> <a role="menuitem" lang="zh_cn" >中文 – 简体</a> </li> </ul> </devsite-language-selector> </nav> </div> </devsite-footer-utility> <devsite-panel></devsite-panel> </section></section> <devsite-sitemask></devsite-sitemask> <devsite-snackbar></devsite-snackbar> <devsite-tooltip ></devsite-tooltip> <devsite-heading-link></devsite-heading-link> <devsite-analytics> <script type="application/json" analytics>[{"dimensions": {"dimension12": false, "dimension6": "en", "dimension5": "en", "dimension1": "Signed out", "dimension4": "TensorFlow Core", "dimension3": false}, "gaid": "UA-69864048-1", "metrics": {"ratings_value": "metric1", "ratings_count": "metric2"}, "purpose": 0}]</script> <script type="application/json" tag-management>{"at": "True", "ga4": [], "ga4p": [], "gtm": [{"id": "GTM-MXSL34P", "purpose": 0}], "parameters": {"internalUser": "False", "language": {"machineTranslated": "False", "requested": "en", "served": "en"}, "pageType": "article", "projectName": "TensorFlow Core", "signedIn": "False", "tenant": "tensorflow", "recommendations": {"sourcePage": "", "sourceType": 0, "sourceRank": 0, "sourceIdenticalDescriptions": 0, "sourceTitleWords": 0, "sourceDescriptionWords": 0, "experiment": ""}, "experiment": {"ids": ""}}}</script> </devsite-analytics> <devsite-badger></devsite-badger> <script nonce="OU19WnIPDbFzppFpGpOI/HTxPKoU2e"> (function(d,e,v,s,i,t,E){d['GoogleDevelopersObject']=i; t=e.createElement(v);t.async=1;t.src=s;E=e.getElementsByTagName(v)[0]; E.parentNode.insertBefore(t,E);})(window, document, 'script', 'https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/js/app_loader.js', 
'[15,"en",null,"/js/devsite_app_module.js","https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625","https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow","https://tensorflow-dot-devsite-v2-prod-3p.appspot.com",null,null,["/_pwa/tensorflow/manifest.json","https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/images/video-placeholder.svg","https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/images/favicon.png","https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/images/lockup.svg","https://fonts.googleapis.com/css?family=Google+Sans:400,500|Roboto:400,400italic,500,500italic,700,700italic|Roboto+Mono:400,500,700&display=swap"],1,null,[1,6,8,12,14,17,21,25,50,52,63,70,75,76,80,87,91,92,93,97,98,100,101,102,103,104,105,107,108,109,110,112,113,116,117,118,120,122,124,125,126,127,129,130,131,132,133,134,135,136,138,140,141,147,148,149,151,152,156,157,158,159,161,163,164,168,169,170,179,180,182,183,186,191,193,196],"AIzaSyCNm9YxQumEXwGJgTDjxoxXK6m1F-9720Q","AIzaSyCc76DZePGtoyUjqKrLdsMGk_ry7sljLbY","www.tensorflow.org","AIzaSyB9bqgQ2t11WJsOX8qNsCQ6U-w91mmqF-I","AIzaSyAdYnStPdzjcJJtQ0mvIaeaMKj7_t6J_Fg",null,null,null,["CloudShell__cloud_shell_button","SignIn__enable_refresh_access_tokens","Profiles__enable_profile_collections","Concierge__enable_pushui","Profiles__enable_public_developer_profiles","Profiles__enable_awarding_url","MiscFeatureFlags__emergency_css","MiscFeatureFlags__developers_footer_dark_image","CloudShell__cloud_code_overflow_menu","Profiles__enable_complete_playlist_endpoint","DevPro__enable_developer_subscriptions","MiscFeatureFlags__enable_view_transitions","Profiles__enable_dashboard_curated_recommendations","MiscFeatureFlags__enable_explain_
this_code","TpcFeatures__enable_required_headers","Experiments__reqs_query_experiments","Search__enable_dynamic_content_confidential_banner","Search__enable_ai_eligibility_checks","Analytics__enable_clearcut_logging","Cloud__enable_cloudx_ping","MiscFeatureFlags__enable_variable_operator","BookNav__enable_tenant_cache_key","Cloud__enable_legacy_calculator_redirect","Profiles__require_profile_eligibility_for_signin","EngEduTelemetry__enable_engedu_telemetry","Cloud__enable_cloudx_experiment_ids","Cloud__enable_cloud_shell","Profiles__enable_completecodelab_endpoint","Cloud__enable_cloud_facet_chat","Cloud__enable_cloud_dlp_service","MiscFeatureFlags__developers_footer_image","Profiles__enable_recognition_badges","Profiles__enable_release_notes_notifications","Search__enable_page_map","MiscFeatureFlags__enable_project_variables","TpcFeatures__enable_mirror_tenant_redirects","DevPro__enable_cloud_innovators_plus","Cloud__enable_free_trial_server_call","Profiles__enable_page_saving","Profiles__enable_developer_profiles_callout","Search__enable_suggestions_from_borg","Cloud__enable_cloud_shell_fte_user_flow","MiscFeatureFlags__enable_firebase_utm","Cloud__enable_llm_concierge_chat"],null,null,"AIzaSyA58TaKli1DculwmAmbpzLVGuWc8eCQgQc","https://developerscontentserving-pa.googleapis.com","AIzaSyDWBU60w0P9hEkr29kkksYs8Z7gvZ8u_wc","https://developerscontentsearch-pa.googleapis.com",2,4,null,"https://developerprofiles-pa.googleapis.com",[15,"tensorflow","TensorFlow","www.tensorflow.org",null,"tensorflow-dot-devsite-v2-prod-3p.appspot.com",null,null,[null,1,null,null,null,null,null,null,null,null,null,[1],null,null,null,null,null,null,[1],[1,null,null,[1]],null,null,null,[1,null,1],[1,1,null,1,1]],null,[25,null,null,null,null,null,"/images/lockup.svg","/images/logo.png",null,null,null,1,1,null,null,null,null,null,null,null,null,1,null,null,null,null,[]],[],null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,[6
,1],null,[[],[1,1]],[[["UA-69864048-1"],["UA-69864048-4"],null,null,["UA-69864048-5"],["GTM-MXSL34P"],null,null,[["UA-69864048-1",1]],null,[["UA-69864048-5",1]],[["GTM-MXSL34P",1]],1],[[4,3],[1,1],[6,5],[3,2],[12,8],[5,4]],[[2,2],[1,1]]],null,4]]') </script> <devsite-a11y-announce></devsite-a11y-announce> </body> </html>