<!-- CINXE.COM -->
<!-- Neural machine translation with a Transformer and Keras | Text | TensorFlow -->
<!doctype html> <html lang="en" dir="ltr"> <head> <meta name="google-signin-client-id" content="157101835696-ooapojlodmuabs2do2vuhhnf90bccmoi.apps.googleusercontent.com"> <meta name="google-signin-scope" content="profile email https://www.googleapis.com/auth/developerprofiles https://www.googleapis.com/auth/developerprofiles.award"> <meta property="og:site_name" content="TensorFlow"> <meta property="og:type" content="website"><meta name="theme-color" content="#ff6f00"><meta charset="utf-8"> <meta content="IE=Edge" http-equiv="X-UA-Compatible"> <meta name="viewport" content="width=device-width, initial-scale=1"> <link rel="manifest" href="/_pwa/tensorflow/manifest.json" crossorigin="use-credentials"> <link rel="preconnect" href="//www.gstatic.com" crossorigin> <link rel="preconnect" href="//fonts.gstatic.com" crossorigin> <link rel="preconnect" href="//fonts.googleapis.com" crossorigin> <link rel="preconnect" href="//apis.google.com" crossorigin> <link rel="preconnect" href="//www.google-analytics.com" crossorigin><link rel="stylesheet" href="//fonts.googleapis.com/css?family=Google+Sans:400,500|Roboto:400,400italic,500,500italic,700,700italic|Roboto+Mono:400,500,700&display=swap"> <link rel="stylesheet" href="//fonts.googleapis.com/css2?family=Material+Icons&family=Material+Symbols+Outlined&display=block"><link rel="stylesheet" href="https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/css/app.css"> <link rel="shortcut icon" href="https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/images/favicon.png"> <link rel="apple-touch-icon" href="https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/images/apple-touch-icon-180x180.png"><link rel="canonical" href="https://www.tensorflow.org/text/tutorials/transformer"><link rel="search" 
type="application/opensearchdescription+xml" title="TensorFlow" href="https://www.tensorflow.org/s/opensearch.xml"> <link rel="alternate" hreflang="en" href="https://www.tensorflow.org/text/tutorials/transformer" /><link rel="alternate" hreflang="x-default" href="https://www.tensorflow.org/text/tutorials/transformer" /><link rel="alternate" hreflang="ar" href="https://www.tensorflow.org/text/tutorials/transformer?hl=ar" /><link rel="alternate" hreflang="bn" href="https://www.tensorflow.org/text/tutorials/transformer?hl=bn" /><link rel="alternate" hreflang="fa" href="https://www.tensorflow.org/text/tutorials/transformer?hl=fa" /><link rel="alternate" hreflang="fr" href="https://www.tensorflow.org/text/tutorials/transformer?hl=fr" /><link rel="alternate" hreflang="he" href="https://www.tensorflow.org/text/tutorials/transformer?hl=he" /><link rel="alternate" hreflang="hi" href="https://www.tensorflow.org/text/tutorials/transformer?hl=hi" /><link rel="alternate" hreflang="id" href="https://www.tensorflow.org/text/tutorials/transformer?hl=id" /><link rel="alternate" hreflang="it" href="https://www.tensorflow.org/text/tutorials/transformer?hl=it" /><link rel="alternate" hreflang="ja" href="https://www.tensorflow.org/text/tutorials/transformer?hl=ja" /><link rel="alternate" hreflang="ko" href="https://www.tensorflow.org/text/tutorials/transformer?hl=ko" /><link rel="alternate" hreflang="pl" href="https://www.tensorflow.org/text/tutorials/transformer?hl=pl" /><link rel="alternate" hreflang="pt-BR" href="https://www.tensorflow.org/text/tutorials/transformer?hl=pt-br" /><link rel="alternate" hreflang="ru" href="https://www.tensorflow.org/text/tutorials/transformer?hl=ru" /><link rel="alternate" hreflang="es-419" href="https://www.tensorflow.org/text/tutorials/transformer?hl=es-419" /><link rel="alternate" hreflang="th" href="https://www.tensorflow.org/text/tutorials/transformer?hl=th" /><link rel="alternate" hreflang="tr" 
href="https://www.tensorflow.org/text/tutorials/transformer?hl=tr" /><link rel="alternate" hreflang="vi" href="https://www.tensorflow.org/text/tutorials/transformer?hl=vi" /><title>Neural machine translation with a Transformer and Keras | Text | TensorFlow</title> <meta property="og:title" content="Neural machine translation with a Transformer and Keras | Text | TensorFlow"><meta property="og:url" content="https://www.tensorflow.org/text/tutorials/transformer"><meta property="og:image" content="https://www.tensorflow.org/static/images/tf_logo_social.png"> <meta property="og:image:width" content="1200"> <meta property="og:image:height" content="675"><meta property="og:locale" content="en"><meta name="twitter:card" content="summary_large_image"><script type="application/ld+json"> { "@context": "https://schema.org", "@type": "Article", "headline": "Neural machine translation with a Transformer and Keras" } </script><script type="application/ld+json"> { "@context": "https://schema.org", "@type": "BreadcrumbList", "itemListElement": [{ "@type": "ListItem", "position": 1, "name": "Text", "item": "https://www.tensorflow.org/text" },{ "@type": "ListItem", "position": 2, "name": "Neural machine translation with a Transformer and Keras", "item": "https://www.tensorflow.org/text/tutorials/transformer" }] } </script> <link rel="stylesheet" href="/extras.css"></head> <body class="" template="page" theme="tensorflow-theme" type="article" layout="docs" display-toc pending> <devsite-progress type="indeterminate" id="app-progress"></devsite-progress> <section class="devsite-wrapper"> <devsite-cookie-notification-bar></devsite-cookie-notification-bar><devsite-header role="banner"> <div class="devsite-header--inner nocontent"> <div class="devsite-top-logo-row-wrapper-wrapper"> <div class="devsite-top-logo-row-wrapper"> <div class="devsite-top-logo-row"> <button type="button" id="devsite-hamburger-menu" class="devsite-header-icon-button button-flat material-icons gc-analytics-event" 
data-category="Site-Wide Custom Events" data-label="Navigation menu button" visually-hidden aria-label="Open menu"> </button> <div class="devsite-product-name-wrapper"> <a href="/" class="devsite-site-logo-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Site logo" track-type="globalNav" track-name="tensorFlow" track-metadata-position="nav" track-metadata-eventDetail="nav"> <picture> <img src="https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/images/lockup.svg" class="devsite-site-logo" alt="TensorFlow"> </picture> </a> <span class="devsite-product-name"> <ul class="devsite-breadcrumb-list" > <li class="devsite-breadcrumb-item "> </li> </ul> </span> </div> <div class="devsite-top-logo-row-middle"> <div class="devsite-header-upper-tabs"> <devsite-tabs class="upper-tabs"> <nav class="devsite-tabs-wrapper" aria-label="Upper tabs"> <tab > <a href="https://www.tensorflow.org/install" track-metadata-eventdetail="https://www.tensorflow.org/install" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - install" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Install" track-name="install" > Install </a> </tab> <tab class="devsite-dropdown "> <a href="https://www.tensorflow.org/learn" track-metadata-eventdetail="https://www.tensorflow.org/learn" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - learn" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Learn" track-name="learn" > Learn </a> <a href="#" role="button" aria-haspopup="true" aria-expanded="false" aria-label="Dropdown menu for Learn" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/learn" track-metadata-position="nav - learn" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: 
Learn" track-name="learn" class="devsite-tabs-dropdown-toggle devsite-icon devsite-icon-arrow-drop-down"></a> <div class="devsite-tabs-dropdown" aria-label="submenu" hidden> <div class="devsite-tabs-dropdown-content"> <div class="devsite-tabs-dropdown-column tfo-menu-column-learn"> <ul class="devsite-tabs-dropdown-section "> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/learn" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/learn" track-metadata-position="nav - learn" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Introduction </div> <div class="devsite-nav-item-description"> New to TensorFlow? </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/tutorials" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/tutorials" track-metadata-position="nav - learn" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Tutorials </div> <div class="devsite-nav-item-description"> Learn how to use TensorFlow with end-to-end examples </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/guide" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/guide" track-metadata-position="nav - learn" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Guide </div> <div class="devsite-nav-item-description"> Learn framework concepts and components </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/resources/learn-ml" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/resources/learn-ml" track-metadata-position="nav - learn" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Learn ML </div> <div class="devsite-nav-item-description"> Educational resources to master your path with TensorFlow </div> </a> </li> </ul> </div> </div> </div> </tab> <tab class="devsite-dropdown 
"> <a href="https://www.tensorflow.org/api" track-metadata-eventdetail="https://www.tensorflow.org/api" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - api" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: API" track-name="api" > API </a> <a href="#" role="button" aria-haspopup="true" aria-expanded="false" aria-label="Dropdown menu for API" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/api" track-metadata-position="nav - api" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: API" track-name="api" class="devsite-tabs-dropdown-toggle devsite-icon devsite-icon-arrow-drop-down"></a> <div class="devsite-tabs-dropdown" aria-label="submenu" hidden> <div class="devsite-tabs-dropdown-content"> <div class="devsite-tabs-dropdown-column "> <ul class="devsite-tabs-dropdown-section "> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/api/stable" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/api/stable" track-metadata-position="nav - api" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> TensorFlow (v2.16.1) </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/versions" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/versions" track-metadata-position="nav - api" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Versions… </div> </a> </li> </ul> </div> <div class="devsite-tabs-dropdown-column "> <ul class="devsite-tabs-dropdown-section "> <li class="devsite-nav-item"> <a href="https://js.tensorflow.org/api/latest/" track-type="nav" track-metadata-eventdetail="https://js.tensorflow.org/api/latest/" track-metadata-position="nav - api" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> TensorFlow.js </div> </a> </li> <li 
class="devsite-nav-item"> <a href="https://www.tensorflow.org/lite/api_docs" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/lite/api_docs" track-metadata-position="nav - api" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> TensorFlow Lite </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/tfx/api_docs" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/tfx/api_docs" track-metadata-position="nav - api" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> TFX </div> </a> </li> </ul> </div> </div> </div> </tab> <tab class="devsite-dropdown devsite-active "> <a href="https://www.tensorflow.org/resources" track-metadata-eventdetail="https://www.tensorflow.org/resources" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - resources" track-metadata-module="primary nav" aria-label="Resources, selected" data-category="Site-Wide Custom Events" data-label="Tab: Resources" track-name="resources" > Resources </a> <a href="#" role="button" aria-haspopup="true" aria-expanded="false" aria-label="Dropdown menu for Resources" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/resources" track-metadata-position="nav - resources" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Resources" track-name="resources" class="devsite-tabs-dropdown-toggle devsite-icon devsite-icon-arrow-drop-down"></a> <div class="devsite-tabs-dropdown" aria-label="submenu" hidden> <div class="devsite-tabs-dropdown-content"> <div class="devsite-tabs-dropdown-column "> <ul class="devsite-tabs-dropdown-section "> <li class="devsite-nav-title" role="heading" tooltip>LIBRARIES</li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/js" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/js" track-metadata-position="nav - 
resources" track-metadata-module="tertiary nav" track-metadata-module_headline="libraries" tooltip > <div class="devsite-nav-item-title"> TensorFlow.js </div> <div class="devsite-nav-item-description"> Develop web ML applications in JavaScript </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/lite" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/lite" track-metadata-position="nav - resources" track-metadata-module="tertiary nav" track-metadata-module_headline="libraries" tooltip > <div class="devsite-nav-item-title"> TensorFlow Lite </div> <div class="devsite-nav-item-description"> Deploy ML on mobile, microcontrollers and other edge devices </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/tfx" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/tfx" track-metadata-position="nav - resources" track-metadata-module="tertiary nav" track-metadata-module_headline="libraries" tooltip > <div class="devsite-nav-item-title"> TFX </div> <div class="devsite-nav-item-description"> Build production ML pipelines </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/resources/libraries-extensions" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/resources/libraries-extensions" track-metadata-position="nav - resources" track-metadata-module="tertiary nav" track-metadata-module_headline="libraries" tooltip > <div class="devsite-nav-item-title"> All libraries </div> <div class="devsite-nav-item-description"> Create advanced models and extend TensorFlow </div> </a> </li> </ul> </div> <div class="devsite-tabs-dropdown-column "> <ul class="devsite-tabs-dropdown-section "> <li class="devsite-nav-title" role="heading" tooltip>RESOURCES</li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/resources/models-datasets" track-type="nav" 
track-metadata-eventdetail="https://www.tensorflow.org/resources/models-datasets" track-metadata-position="nav - resources" track-metadata-module="tertiary nav" track-metadata-module_headline="resources" tooltip > <div class="devsite-nav-item-title"> Models & datasets </div> <div class="devsite-nav-item-description"> Pre-trained models and datasets built by Google and the community </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/resources/tools" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/resources/tools" track-metadata-position="nav - resources" track-metadata-module="tertiary nav" track-metadata-module_headline="resources" tooltip > <div class="devsite-nav-item-title"> Tools </div> <div class="devsite-nav-item-description"> Tools to support and accelerate TensorFlow workflows </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/responsible_ai" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/responsible_ai" track-metadata-position="nav - resources" track-metadata-module="tertiary nav" track-metadata-module_headline="resources" tooltip > <div class="devsite-nav-item-title"> Responsible AI </div> <div class="devsite-nav-item-description"> Resources for every stage of the ML workflow </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/resources/recommendation-systems" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/resources/recommendation-systems" track-metadata-position="nav - resources" track-metadata-module="tertiary nav" track-metadata-module_headline="resources" tooltip > <div class="devsite-nav-item-title"> Recommendation systems </div> <div class="devsite-nav-item-description"> Build recommendation systems with open source tools </div> </a> </li> </ul> </div> </div> </div> </tab> <tab class="devsite-dropdown "> <a href="https://www.tensorflow.org/community" 
track-metadata-eventdetail="https://www.tensorflow.org/community" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - community" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Community" track-name="community" > Community </a> <a href="#" role="button" aria-haspopup="true" aria-expanded="false" aria-label="Dropdown menu for Community" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/community" track-metadata-position="nav - community" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Community" track-name="community" class="devsite-tabs-dropdown-toggle devsite-icon devsite-icon-arrow-drop-down"></a> <div class="devsite-tabs-dropdown" aria-label="submenu" hidden> <div class="devsite-tabs-dropdown-content"> <div class="devsite-tabs-dropdown-column "> <ul class="devsite-tabs-dropdown-section "> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/community/groups" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/community/groups" track-metadata-position="nav - community" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Groups </div> <div class="devsite-nav-item-description"> User groups, interest groups and mailing lists </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/community/contribute" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/community/contribute" track-metadata-position="nav - community" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Contribute </div> <div class="devsite-nav-item-description"> Guide for contributing to code and documentation </div> </a> </li> <li class="devsite-nav-item"> <a href="https://blog.tensorflow.org/" track-type="nav" track-metadata-eventdetail="https://blog.tensorflow.org/" track-metadata-position="nav - 
community" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Blog </div> <div class="devsite-nav-item-description"> Stay up to date with all things TensorFlow </div> </a> </li> <li class="devsite-nav-item"> <a href="https://discuss.tensorflow.org" track-type="nav" track-metadata-eventdetail="https://discuss.tensorflow.org" track-metadata-position="nav - community" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Forum </div> <div class="devsite-nav-item-description"> Discussion platform for the TensorFlow community </div> </a> </li> </ul> </div> </div> </div> </tab> <tab class="devsite-dropdown "> <a href="https://www.tensorflow.org/about" track-metadata-eventdetail="https://www.tensorflow.org/about" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - why tensorflow" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Why TensorFlow" track-name="why tensorflow" > Why TensorFlow </a> <a href="#" role="button" aria-haspopup="true" aria-expanded="false" aria-label="Dropdown menu for Why TensorFlow" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/about" track-metadata-position="nav - why tensorflow" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Why TensorFlow" track-name="why tensorflow" class="devsite-tabs-dropdown-toggle devsite-icon devsite-icon-arrow-drop-down"></a> <div class="devsite-tabs-dropdown" aria-label="submenu" hidden> <div class="devsite-tabs-dropdown-content"> <div class="devsite-tabs-dropdown-column "> <ul class="devsite-tabs-dropdown-section "> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/about" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/about" track-metadata-position="nav - why tensorflow" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> About 
</div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/about/case-studies" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/about/case-studies" track-metadata-position="nav - why tensorflow" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Case studies </div> </a> </li> </ul> </div> </div> </div> </tab> </nav> </devsite-tabs> </div> <devsite-search enable-signin enable-search enable-suggestions enable-query-completion project-name="Text" tenant-name="TensorFlow" > <form class="devsite-search-form" action="https://www.tensorflow.org/s/results" method="GET"> <div class="devsite-search-container"> <button type="button" search-open class="devsite-search-button devsite-header-icon-button button-flat material-icons" aria-label="Open search"></button> <div class="devsite-searchbox"> <input aria-activedescendant="" aria-autocomplete="list" aria-label="Search" aria-expanded="false" aria-haspopup="listbox" autocomplete="off" class="devsite-search-field devsite-search-query" name="q" placeholder="Search" role="combobox" type="text" value="" > <div class="devsite-search-image material-icons" aria-hidden="true"> </div> <div class="devsite-search-shortcut-icon-container" aria-hidden="true"> <kbd class="devsite-search-shortcut-icon">/</kbd> </div> </div> </div> </form> <button type="button" search-close class="devsite-search-button devsite-header-icon-button button-flat material-icons" aria-label="Close search"></button> </devsite-search> </div> <devsite-language-selector> <ul role="presentation"> <li role="presentation"> <a role="menuitem" lang="en" >English</a> </li> <li role="presentation"> <a role="menuitem" lang="es_419" >Español – América Latina</a> </li> <li role="presentation"> <a role="menuitem" lang="fr" >Français</a> </li> <li role="presentation"> <a role="menuitem" lang="id" >Indonesia</a> </li> <li role="presentation"> <a role="menuitem" lang="it" >Italiano</a> </li> <li 
role="presentation"> <a role="menuitem" lang="pl" >Polski</a> </li> <li role="presentation"> <a role="menuitem" lang="pt_br" >Português – Brasil</a> </li> <li role="presentation"> <a role="menuitem" lang="vi" >Tiếng Việt</a> </li> <li role="presentation"> <a role="menuitem" lang="tr" >Türkçe</a> </li> <li role="presentation"> <a role="menuitem" lang="ru" >Русский</a> </li> <li role="presentation"> <a role="menuitem" lang="he" >עברית</a> </li> <li role="presentation"> <a role="menuitem" lang="ar" >العربيّة</a> </li> <li role="presentation"> <a role="menuitem" lang="fa" >فارسی</a> </li> <li role="presentation"> <a role="menuitem" lang="hi" >हिंदी</a> </li> <li role="presentation"> <a role="menuitem" lang="bn" >বাংলা</a> </li> <li role="presentation"> <a role="menuitem" lang="th" >ภาษาไทย</a> </li> <li role="presentation"> <a role="menuitem" lang="zh_cn" >中文 – 简体</a> </li> <li role="presentation"> <a role="menuitem" lang="ja" >日本語</a> </li> <li role="presentation"> <a role="menuitem" lang="ko" >한국어</a> </li> </ul> </devsite-language-selector> <a class="devsite-header-link devsite-top-button button gc-analytics-event" href="//github.com/tensorflow" data-category="Site-Wide Custom Events" data-label="Site header link" > GitHub </a> <devsite-user enable-profiles id="devsite-user"> <span class="button devsite-top-button" aria-hidden="true" visually-hidden>Sign in</span> </devsite-user> </div> </div> </div> <div class="devsite-collapsible-section "> <div class="devsite-header-background"> <div class="devsite-product-id-row" > <div class="devsite-product-description-row"> <ul class="devsite-breadcrumb-list" > <li class="devsite-breadcrumb-item "> <a href="https://www.tensorflow.org/text" class="devsite-breadcrumb-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Lower Header" data-value="1" track-type="globalNav" track-name="breadcrumb" track-metadata-position="1" track-metadata-eventdetail="Text" > Text </a> </li> </ul> </div> </div> <div 
class="devsite-doc-set-nav-row"> <devsite-tabs class="lower-tabs"> <nav class="devsite-tabs-wrapper" aria-label="Lower tabs"> <tab > <a href="https://www.tensorflow.org/text" track-metadata-eventdetail="https://www.tensorflow.org/text" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - overview" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Overview" track-name="overview" > Overview </a> </tab> <tab class="devsite-active"> <a href="https://www.tensorflow.org/text/tutorials" track-metadata-eventdetail="https://www.tensorflow.org/text/tutorials" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - tutorials" track-metadata-module="primary nav" aria-label="Tutorials, selected" data-category="Site-Wide Custom Events" data-label="Tab: Tutorials" track-name="tutorials" > Tutorials </a> </tab> <tab > <a href="https://www.tensorflow.org/text/guide" track-metadata-eventdetail="https://www.tensorflow.org/text/guide" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - guide" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Guide" track-name="guide" > Guide </a> </tab> <tab > <a href="https://www.tensorflow.org/text/api_overview" track-metadata-eventdetail="https://www.tensorflow.org/text/api_overview" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - api" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: API" track-name="api" > API </a> </tab> </nav> </devsite-tabs> </div> </div> </div> </div> </devsite-header> <devsite-book-nav scrollbars > <div class="devsite-book-nav-filter" > <span class="filter-list-icon material-icons" aria-hidden="true"></span> <input type="text" placeholder="Filter" aria-label="Type to filter" role="searchbox"> <span 
class="filter-clear-button hidden" data-title="Clear filter" aria-label="Clear filter" role="button" tabindex="0"></span> </div> <nav class="devsite-book-nav devsite-nav nocontent" aria-label="Side menu"> <div class="devsite-mobile-header"> <button type="button" id="devsite-close-nav" class="devsite-header-icon-button button-flat material-icons gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Close navigation" aria-label="Close navigation"> </button> <div class="devsite-product-name-wrapper"> <a href="/" class="devsite-site-logo-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Site logo" track-type="globalNav" track-name="tensorFlow" track-metadata-position="nav" track-metadata-eventDetail="nav"> <picture> <img src="https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/images/lockup.svg" class="devsite-site-logo" alt="TensorFlow"> </picture> </a> <span class="devsite-product-name"> <ul class="devsite-breadcrumb-list" > <li class="devsite-breadcrumb-item "> </li> </ul> </span> </div> </div> <div class="devsite-book-nav-wrapper"> <div class="devsite-mobile-nav-top"> <ul class="devsite-nav-list"> <li class="devsite-nav-item"> <a href="/install" class="devsite-nav-title gc-analytics-event devsite-nav-has-children " data-category="Site-Wide Custom Events" data-label="Tab: Install" track-name="install" data-category="Site-Wide Custom Events" data-label="Responsive Tab: Install" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Install </span> <span class="devsite-nav-icon material-icons" data-icon="forward" > </span> </a> </li> <li class="devsite-nav-item"> <a href="/learn" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Tab: Learn" track-name="learn" data-category="Site-Wide Custom Events" data-label="Responsive Tab: 
Learn" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Learn </span> </a> <ul class="devsite-nav-responsive-tabs devsite-nav-has-menu "> <li class="devsite-nav-item"> <span class="devsite-nav-title" tooltip data-category="Site-Wide Custom Events" data-label="Tab: Learn" track-name="learn" > <span class="devsite-nav-text" tooltip menu="Learn"> More </span> <span class="devsite-nav-icon material-icons" data-icon="forward" menu="Learn"> </span> </span> </li> </ul> </li> <li class="devsite-nav-item"> <a href="/api" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Tab: API" track-name="api" data-category="Site-Wide Custom Events" data-label="Responsive Tab: API" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > API </span> </a> <ul class="devsite-nav-responsive-tabs devsite-nav-has-menu "> <li class="devsite-nav-item"> <span class="devsite-nav-title" tooltip data-category="Site-Wide Custom Events" data-label="Tab: API" track-name="api" > <span class="devsite-nav-text" tooltip menu="API"> More </span> <span class="devsite-nav-icon material-icons" data-icon="forward" menu="API"> </span> </span> </li> </ul> </li> <li class="devsite-nav-item"> <a href="/resources" class="devsite-nav-title gc-analytics-event devsite-nav-active" data-category="Site-Wide Custom Events" data-label="Tab: Resources" track-name="resources" data-category="Site-Wide Custom Events" data-label="Responsive Tab: Resources" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Resources </span> </a> <ul class="devsite-nav-responsive-tabs devsite-nav-has-menu "> <li class="devsite-nav-item"> <span class="devsite-nav-title" tooltip data-category="Site-Wide Custom Events" data-label="Tab: Resources" 
track-name="resources" > <span class="devsite-nav-text" tooltip menu="Resources"> More </span> <span class="devsite-nav-icon material-icons" data-icon="forward" menu="Resources"> </span> </span> </li> </ul> <ul class="devsite-nav-responsive-tabs"> <li class="devsite-nav-item"> <a href="/text" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Tab: Overview" track-name="overview" data-category="Site-Wide Custom Events" data-label="Responsive Tab: Overview" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Overview </span> </a> </li> <li class="devsite-nav-item"> <a href="/text/tutorials" class="devsite-nav-title gc-analytics-event devsite-nav-has-children devsite-nav-active" data-category="Site-Wide Custom Events" data-label="Tab: Tutorials" track-name="tutorials" data-category="Site-Wide Custom Events" data-label="Responsive Tab: Tutorials" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip menu="_book"> Tutorials </span> <span class="devsite-nav-icon material-icons" data-icon="forward" menu="_book"> </span> </a> </li> <li class="devsite-nav-item"> <a href="/text/guide" class="devsite-nav-title gc-analytics-event devsite-nav-has-children " data-category="Site-Wide Custom Events" data-label="Tab: Guide" track-name="guide" data-category="Site-Wide Custom Events" data-label="Responsive Tab: Guide" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Guide </span> <span class="devsite-nav-icon material-icons" data-icon="forward" > </span> </a> </li> <li class="devsite-nav-item"> <a href="/text/api_overview" class="devsite-nav-title gc-analytics-event devsite-nav-has-children " data-category="Site-Wide Custom Events" data-label="Tab: API" track-name="api" data-category="Site-Wide 
Custom Events" data-label="Responsive Tab: API" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > API </span> <span class="devsite-nav-icon material-icons" data-icon="forward" > </span> </a> </li> </ul> </li> <li class="devsite-nav-item"> <a href="/community" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Tab: Community" track-name="community" data-category="Site-Wide Custom Events" data-label="Responsive Tab: Community" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Community </span> </a> <ul class="devsite-nav-responsive-tabs devsite-nav-has-menu "> <li class="devsite-nav-item"> <span class="devsite-nav-title" tooltip data-category="Site-Wide Custom Events" data-label="Tab: Community" track-name="community" > <span class="devsite-nav-text" tooltip menu="Community"> More </span> <span class="devsite-nav-icon material-icons" data-icon="forward" menu="Community"> </span> </span> </li> </ul> </li> <li class="devsite-nav-item"> <a href="/about" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Tab: Why TensorFlow" track-name="why tensorflow" data-category="Site-Wide Custom Events" data-label="Responsive Tab: Why TensorFlow" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Why TensorFlow </span> </a> <ul class="devsite-nav-responsive-tabs devsite-nav-has-menu "> <li class="devsite-nav-item"> <span class="devsite-nav-title" tooltip data-category="Site-Wide Custom Events" data-label="Tab: Why TensorFlow" track-name="why tensorflow" > <span class="devsite-nav-text" tooltip menu="Why TensorFlow"> More </span> <span class="devsite-nav-icon material-icons" data-icon="forward" menu="Why TensorFlow"> </span> </span> </li> 
</ul> </li> <li class="devsite-nav-item"> <a href="//github.com/tensorflow" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: GitHub" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > GitHub </span> </a> </li> </ul> </div> <div class="devsite-mobile-nav-bottom"> <ul class="devsite-nav-list" menu="_book"> <li class="devsite-nav-item"><a href="/text/tutorials" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /text/tutorials" track-type="bookNav" track-name="click" track-metadata-eventdetail="/text/tutorials" ><span class="devsite-nav-text" tooltip>Overview</span></a></li> <li class="devsite-nav-item devsite-nav-divider devsite-nav-heading"><div class="devsite-nav-title devsite-nav-title-no-path"> <span class="devsite-nav-text" tooltip>Keras NLP</span> </div></li> <li class="devsite-nav-item devsite-nav-external"><a href="https://keras.io/guides/keras_nlp/getting_started/" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: https://keras.io/guides/keras_nlp/getting_started/" track-type="bookNav" track-name="click" track-metadata-eventdetail="https://keras.io/guides/keras_nlp/getting_started/" ><span class="devsite-nav-text" tooltip>Get started with KerasNLP</span><span class="devsite-nav-icon material-icons" data-icon="external" data-title="External" aria-hidden="true"></span></a></li> <li class="devsite-nav-item devsite-nav-divider devsite-nav-heading"><div class="devsite-nav-title devsite-nav-title-no-path"> <span class="devsite-nav-text" tooltip>Text Generation</span> </div></li> <li class="devsite-nav-item"><a href="/text/tutorials/text_generation" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: 
/text/tutorials/text_generation" track-type="bookNav" track-name="click" track-metadata-eventdetail="/text/tutorials/text_generation" ><span class="devsite-nav-text" tooltip>Generate Text with RNNs</span></a></li> <li class="devsite-nav-item"><a href="/text/tutorials/nmt_with_attention" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /text/tutorials/nmt_with_attention" track-type="bookNav" track-name="click" track-metadata-eventdetail="/text/tutorials/nmt_with_attention" ><span class="devsite-nav-text" tooltip>Translate text with seq2seq models</span></a></li> <li class="devsite-nav-item"><a href="/text/tutorials/transformer" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /text/tutorials/transformer" track-type="bookNav" track-name="click" track-metadata-eventdetail="/text/tutorials/transformer" ><span class="devsite-nav-text" tooltip>Translate text with transformer models</span></a></li> <li class="devsite-nav-item"><a href="/text/tutorials/image_captioning" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /text/tutorials/image_captioning" track-type="bookNav" track-name="click" track-metadata-eventdetail="/text/tutorials/image_captioning" ><span class="devsite-nav-text" tooltip>Image captioning</span></a></li> <li class="devsite-nav-item devsite-nav-divider devsite-nav-heading"><div class="devsite-nav-title devsite-nav-title-no-path"> <span class="devsite-nav-text" tooltip>Text Classification</span> </div></li> <li class="devsite-nav-item"><a href="/text/tutorials/classify_text_with_bert" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /text/tutorials/classify_text_with_bert" track-type="bookNav" track-name="click" 
track-metadata-eventdetail="/text/tutorials/classify_text_with_bert" ><span class="devsite-nav-text" tooltip>Text classification with BERT</span></a></li> <li class="devsite-nav-item"><a href="/text/tutorials/text_classification_rnn" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /text/tutorials/text_classification_rnn" track-type="bookNav" track-name="click" track-metadata-eventdetail="/text/tutorials/text_classification_rnn" ><span class="devsite-nav-text" tooltip>Text classification with RNNs</span></a></li> <li class="devsite-nav-item"><a href="/text/tutorials/text_similarity" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /text/tutorials/text_similarity" track-type="bookNav" track-name="click" track-metadata-eventdetail="/text/tutorials/text_similarity" ><span class="devsite-nav-text" tooltip>Compute Similarity Metrics</span></a></li> <li class="devsite-nav-item devsite-nav-divider devsite-nav-heading"><div class="devsite-nav-title devsite-nav-title-no-path"> <span class="devsite-nav-text" tooltip>NLP with BERT</span> </div></li> <li class="devsite-nav-item"><a href="/text/tutorials/bert_glue" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /text/tutorials/bert_glue" track-type="bookNav" track-name="click" track-metadata-eventdetail="/text/tutorials/bert_glue" ><span class="devsite-nav-text" tooltip>Fine Tune Bert on GLUE tasks</span></a></li> <li class="devsite-nav-item"><a href="/tfmodels/nlp/fine_tune_bert" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /tfmodels/nlp/fine_tune_bert" track-type="bookNav" track-name="click" track-metadata-eventdetail="/tfmodels/nlp/fine_tune_bert" ><span class="devsite-nav-text" tooltip>Fine tune BERT</span></a></li> <li 
class="devsite-nav-item"><a href="/text/tutorials/uncertainty_quantification_with_sngp_bert" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /text/tutorials/uncertainty_quantification_with_sngp_bert" track-type="bookNav" track-name="click" track-metadata-eventdetail="/text/tutorials/uncertainty_quantification_with_sngp_bert" ><span class="devsite-nav-text" tooltip>Quantify uncertainty with BERT</span></a></li> <li class="devsite-nav-item devsite-nav-divider devsite-nav-heading"><div class="devsite-nav-title devsite-nav-title-no-path"> <span class="devsite-nav-text" tooltip>Embeddings</span> </div></li> <li class="devsite-nav-item"><a href="/text/tutorials/word_embeddings" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /text/tutorials/word_embeddings" track-type="bookNav" track-name="click" track-metadata-eventdetail="/text/tutorials/word_embeddings" ><span class="devsite-nav-text" tooltip>Word embeddings</span></a></li> <li class="devsite-nav-item"><a href="/text/tutorials/warmstart_embedding_matrix" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /text/tutorials/warmstart_embedding_matrix" track-type="bookNav" track-name="click" track-metadata-eventdetail="/text/tutorials/warmstart_embedding_matrix" ><span class="devsite-nav-text" tooltip>Warmstarting embeddings</span></a></li> <li class="devsite-nav-item"><a href="/text/tutorials/word2vec" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /text/tutorials/word2vec" track-type="bookNav" track-name="click" track-metadata-eventdetail="/text/tutorials/word2vec" ><span class="devsite-nav-text" tooltip>Word2Vec</span></a></li> </ul> <ul class="devsite-nav-list" menu="Learn" aria-label="Side menu" hidden> <li 
class="devsite-nav-item"> <a href="/learn" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Introduction" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Introduction </span> </a> </li> <li class="devsite-nav-item"> <a href="/tutorials" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Tutorials" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Tutorials </span> </a> </li> <li class="devsite-nav-item"> <a href="/guide" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Guide" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Guide </span> </a> </li> <li class="devsite-nav-item"> <a href="/resources/learn-ml" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Learn ML" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Learn ML </span> </a> </li> </ul> <ul class="devsite-nav-list" menu="API" aria-label="Side menu" hidden> <li class="devsite-nav-item"> <a href="/api/stable" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: TensorFlow (v2.16.1)" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > TensorFlow (v2.16.1) </span> </a> </li> <li class="devsite-nav-item"> <a href="/versions" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Versions…" track-type="navMenu" 
track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Versions… </span> </a> </li> <li class="devsite-nav-item"> <a href="https://js.tensorflow.org/api/latest/" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: TensorFlow.js" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > TensorFlow.js </span> </a> </li> <li class="devsite-nav-item"> <a href="/lite/api_docs" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: TensorFlow Lite" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > TensorFlow Lite </span> </a> </li> <li class="devsite-nav-item"> <a href="/tfx/api_docs" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: TFX" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > TFX </span> </a> </li> </ul> <ul class="devsite-nav-list" menu="Resources" aria-label="Side menu" hidden> <li class="devsite-nav-item devsite-nav-heading"> <span class="devsite-nav-title" tooltip > <span class="devsite-nav-text" tooltip > LIBRARIES </span> </span> </li> <li class="devsite-nav-item"> <a href="/js" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: TensorFlow.js" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > TensorFlow.js </span> </a> </li> <li class="devsite-nav-item"> <a href="/lite" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: TensorFlow Lite" track-type="navMenu" 
track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > TensorFlow Lite </span> </a> </li> <li class="devsite-nav-item"> <a href="/tfx" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: TFX" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > TFX </span> </a> </li> <li class="devsite-nav-item"> <a href="/resources/libraries-extensions" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: All libraries" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > All libraries </span> </a> </li> <li class="devsite-nav-item devsite-nav-heading"> <span class="devsite-nav-title" tooltip > <span class="devsite-nav-text" tooltip > RESOURCES </span> </span> </li> <li class="devsite-nav-item"> <a href="/resources/models-datasets" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Models & datasets" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Models & datasets </span> </a> </li> <li class="devsite-nav-item"> <a href="/resources/tools" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Tools" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Tools </span> </a> </li> <li class="devsite-nav-item"> <a href="/responsible_ai" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Responsible AI" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span 
class="devsite-nav-text" tooltip > Responsible AI </span> </a> </li> <li class="devsite-nav-item"> <a href="/resources/recommendation-systems" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Recommendation systems" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Recommendation systems </span> </a> </li> </ul> <ul class="devsite-nav-list" menu="Community" aria-label="Side menu" hidden> <li class="devsite-nav-item"> <a href="/community/groups" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Groups" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Groups </span> </a> </li> <li class="devsite-nav-item"> <a href="/community/contribute" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Contribute" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Contribute </span> </a> </li> <li class="devsite-nav-item"> <a href="https://blog.tensorflow.org/" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Blog" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Blog </span> </a> </li> <li class="devsite-nav-item"> <a href="https://discuss.tensorflow.org" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Forum" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Forum </span> </a> </li> </ul> <ul class="devsite-nav-list" menu="Why TensorFlow" aria-label="Side 
menu" hidden> <li class="devsite-nav-item"> <a href="/about" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: About" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > About </span> </a> </li> <li class="devsite-nav-item"> <a href="/about/case-studies" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Case studies" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Case studies </span> </a> </li> </ul> </div> </div> </nav> </devsite-book-nav> <section id="gc-wrapper"> <main role="main" class="devsite-main-content" has-book-nav has-sidebar > <div class="devsite-sidebar"> <div class="devsite-sidebar-content"> <devsite-toc class="devsite-nav" role="navigation" aria-label="On this page" depth="2" scrollbars ></devsite-toc> <devsite-recommendations-sidebar class="nocontent devsite-nav"> </devsite-recommendations-sidebar> </div> </div> <devsite-content> <article class="devsite-article"><style> /* Styles inlined from /site-assets/css/style.css */ /* override theme */ table img { max-width: 100%; } /* .devsite-terminal virtualenv prompt */ .tfo-terminal-venv::before { content: "(venv) $ " !important; } /* .devsite-terminal root prompt */ .tfo-terminal-root::before { content: "# " !important; } /* Used in links for type annotations in function/method signatures */ .tfo-signature-link a, .tfo-signature-link a:visited, .tfo-signature-link a:hover, .tfo-signature-link a:focus, .tfo-signature-link a:hover *, .tfo-signature-link a:focus * { text-decoration: none !important; } .tfo-signature-link a, .tfo-signature-link a:visited { border-bottom: 1px dotted #1a73e8; } .tfo-signature-link a:focus { border-bottom-style: solid; } /* .devsite-terminal Windows prompt */ 
.tfo-terminal-windows::before { content: "C:\\> " !important; } /* .devsite-terminal Windows prompt w/ virtualenv */ .tfo-terminal-windows-venv::before { content: "(venv) C:\\> " !important; } .tfo-diff-green-one-level + * { background: rgba(175, 245, 162, .6) !important; } .tfo-diff-green + * > * { background: rgba(175, 245, 162, .6) !important; } .tfo-diff-green-list + ul > li:first-of-type { background: rgba(175, 245, 162, .6) !important; } .tfo-diff-red-one-level + * { background: rgba(255, 230, 230, .6) !important; text-decoration: line-through !important; } .tfo-diff-red + * > * { background: rgba(255, 230, 230, .6) !important; text-decoration: line-through !important; } .tfo-diff-red-list + ul > li:first-of-type { background: rgba(255, 230, 230, .6) !important; text-decoration: line-through !important; } devsite-code .tfo-notebook-code-cell-output { max-height: 300px; overflow: auto; background: rgba(255, 247, 237, 1); /* orange bg to distinguish from input code cells */ } devsite-code .tfo-notebook-code-cell-output + .devsite-code-buttons-container button { background: rgba(255, 247, 237, .7); /* orange bg to distinguish from input code cells */ } devsite-code[dark-code] .tfo-notebook-code-cell-output { background: rgba(64, 78, 103, 1); /* medium slate */ } devsite-code[dark-code] .tfo-notebook-code-cell-output + .devsite-code-buttons-container button { background: rgba(64, 78, 103, .7); /* medium slate */ } /* override default table styles for notebook buttons */ .devsite-table-wrapper .tfo-notebook-buttons { display: inline-block; margin-left: 3px; width: auto; } .tfo-notebook-buttons td { padding-left: 0; padding-right: 20px; } .tfo-notebook-buttons a, .tfo-notebook-buttons :link, .tfo-notebook-buttons :visited { border-radius: 8px; box-shadow: 0 1px 2px 0 rgba(60, 64, 67, .3), 0 1px 3px 1px rgba(60, 64, 67, .15); color: #202124; padding: 12px 17px; transition: box-shadow 0.2s; } .tfo-notebook-buttons a:hover, .tfo-notebook-buttons a:focus { box-shadow: 
0 1px 2px 0 rgba(60, 64, 67, .3), 0 2px 6px 2px rgba(60, 64, 67, .15); } .tfo-notebook-buttons tr { background: 0; border: 0; } /* on rendered notebook page, remove link to webpage since we're already here */ .tfo-notebook-buttons:not(.tfo-api) td:first-child { display: none; } .tfo-notebook-buttons td > a { -webkit-box-align: center; -ms-flex-align: center; align-items: center; display: -webkit-box; display: -ms-flexbox; display: flex; } .tfo-notebook-buttons td > a > img { margin-right: 8px; } /* landing pages */ .tfo-landing-row-item-inset-white { background-color: #fff; padding: 32px; } .tfo-landing-row-item-inset-white ol, .tfo-landing-row-item-inset-white ul { padding-left: 20px; } /* colab callout button */ .colab-callout-row devsite-code { border-radius: 8px 8px 0 0; box-shadow: none; } .colab-callout-footer { background: #e3e4e7; border-radius: 0 0 8px 8px; color: #37474f; padding: 20px; } .colab-callout-row devsite-code[dark-code] + .colab-callout-footer { background: #3f4f66; } .colab-callout-footer > .button { margin-top: 4px; color: #ff5c00; } .colab-callout-footer > a > span { vertical-align: middle; color: #37474f; padding-left: 10px; font-size: 14px; } .colab-callout-row devsite-code[dark-code] + .colab-callout-footer > a > span { color: #fff; } a.colab-button { background: rgba(255, 255, 255, .75); border: solid 1px rgba(0, 0, 0, .08); border-bottom-color: rgba(0, 0, 0, .15); border-radius: 4px; color: #aaa; display: inline-block; font-size: 11px !important; font-weight: 300; line-height: 16px; padding: 4px 8px; text-decoration: none; text-transform: uppercase; } a.colab-button:hover { background: white; border-color: rgba(0, 0, 0, .2); color: #666; } a.colab-button span { background: url(/images/colab_logo_button.svg) no-repeat 1px 1px / 20px; border-radius: 4px; display: inline-block; padding-left: 24px; text-decoration: none; } @media screen and (max-width: 600px) { .tfo-notebook-buttons td { display: block; } } /* guide and tutorials landing 
page cards and sections */ .tfo-landing-page-card { padding: 16px; box-shadow: 0 0 36px rgba(0,0,0,0.1); border-radius: 10px; } /* Page section headings */ .tfo-landing-page-heading h2, h2.tfo-landing-page-heading { font-family: "Google Sans", sans-serif; color: #425066; font-size: 30px; font-weight: 700; line-height: 40px; } /* Item title headings */ .tfo-landing-page-heading h3, h3.tfo-landing-page-heading, .tfo-landing-page-card h3, h3.tfo-landing-page-card { font-family: "Google Sans", sans-serif; color: #425066; font-size: 20px; font-weight: 500; line-height: 26px; } /* List of tutorials notebooks for subsites */ .tfo-landing-page-resources-ul { padding-left: 15px } .tfo-landing-page-resources-ul > li { margin: 6px 0; } /* Temporary fix to hide product description in header on landing pages */ devsite-header .devsite-product-description { display: none; } </style> <div class="devsite-article-meta nocontent" role="navigation"> <ul class="devsite-breadcrumb-list" aria-label="Breadcrumb"> <li class="devsite-breadcrumb-item "> <a href="https://www.tensorflow.org/" class="devsite-breadcrumb-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Breadcrumbs" data-value="1" track-type="globalNav" track-name="breadcrumb" track-metadata-position="1" track-metadata-eventdetail="TensorFlow" > TensorFlow </a> </li> <li class="devsite-breadcrumb-item "> <div class="devsite-breadcrumb-guillemet material-icons" aria-hidden="true"></div> <a href="https://www.tensorflow.org/resources" class="devsite-breadcrumb-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Breadcrumbs" data-value="2" track-type="globalNav" track-name="breadcrumb" track-metadata-position="2" track-metadata-eventdetail="" > Resources </a> </li> <li class="devsite-breadcrumb-item "> <div class="devsite-breadcrumb-guillemet material-icons" aria-hidden="true"></div> <a href="https://www.tensorflow.org/text" class="devsite-breadcrumb-link gc-analytics-event" 
data-category="Site-Wide Custom Events" data-label="Breadcrumbs" data-value="3" track-type="globalNav" track-name="breadcrumb" track-metadata-position="3" track-metadata-eventdetail="Text" > Text </a> </li> <li class="devsite-breadcrumb-item "> <div class="devsite-breadcrumb-guillemet material-icons" aria-hidden="true"></div> <a href="https://www.tensorflow.org/text/tutorials" class="devsite-breadcrumb-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Breadcrumbs" data-value="4" track-type="globalNav" track-name="breadcrumb" track-metadata-position="4" track-metadata-eventdetail="" > Tutorials </a> </li> </ul> <devsite-thumb-rating position="header"> </devsite-thumb-rating> </div> <h1 class="devsite-page-title" tabindex="-1"> Neural machine translation with a Transformer and Keras </h1> <devsite-feature-tooltip ack-key="AckCollectionsBookmarkTooltipDismiss" analytics-category="Site-Wide Custom Events" analytics-action-show="Callout Profile displayed" analytics-action-close="Callout Profile dismissed" analytics-label="Create Collection Callout" class="devsite-page-bookmark-tooltip nocontent" dismiss-button="true" id="devsite-collections-dropdown" dismiss-button-text="Dismiss" close-button-text="Got it"> <devsite-bookmark></devsite-bookmark> <span slot="popout-heading"> Stay organized with collections </span> <span slot="popout-contents"> Save and categorize content based on your preferences. </span> </devsite-feature-tooltip> <div class="devsite-page-title-meta"><devsite-view-release-notes></devsite-view-release-notes></div> <devsite-toc class="devsite-nav" depth="2" devsite-toc-embedded > </devsite-toc> <div class="devsite-article-body clearfix "> <p><devsite-mathjax config="TeX-AMS-MML_SVG"></devsite-mathjax> </p> <!-- DO NOT EDIT! Automatically generated file. 
--> <style> td { text-align: center; } th { text-align: center; } </style> <table class="tfo-notebook-buttons" align="left"> <td> <a target="_blank" href="https://www.tensorflow.org/text/tutorials/transformer"> <img src="https://www.tensorflow.org/images/tf_logo_32px.png"> View on TensorFlow.org</a> </td> <td> <a target="_blank" href="https://colab.research.google.com/github/tensorflow/text/blob/master/docs/tutorials/transformer.ipynb"> <img src="https://www.tensorflow.org/images/colab_logo_32px.png"> Run in Google Colab</a> </td> <td> <a target="_blank" href="https://github.com/tensorflow/text/blob/master/docs/tutorials/transformer.ipynb"> <img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png"> View source on GitHub</a> </td> <td> <a href="https://storage.googleapis.com/tensorflow_docs/text/docs/tutorials/transformer.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png">Download notebook</a> </td> </table> <p>This tutorial demonstrates how to create and train a <a href="https://developers.google.com/machine-learning/glossary#sequence-to-sequence-task">sequence-to-sequence</a> <a href="https://developers.google.com/machine-learning/glossary#Transformer">Transformer</a> model to translate <a href="https://www.tensorflow.org/datasets/catalog/ted_hrlr_translate#ted_hrlr_translatept_to_en">Portuguese into English</a>. The Transformer was originally proposed in <a href="https://arxiv.org/abs/1706.03762">"Attention is all you need"</a> by Vaswani et al. (2017).</p> <p>Transformers are deep neural networks that replace CNNs and RNNs with <a href="https://developers.google.com/machine-learning/glossary#self-attention">self-attention</a>. 
Self-attention allows Transformers to easily transmit information across the input sequences.</p> <p>As explained in the <a href="https://ai.googleblog.com/2017/08/transformer-novel-neural-network.html">Google AI Blog post</a>:</p> <blockquote> <p>Neural networks for machine translation typically contain an encoder reading the input sentence and generating a representation of it. A decoder then generates the output sentence word by word while consulting the representation generated by the encoder. The Transformer starts by generating initial representations, or embeddings, for each word... Then, using self-attention, it aggregates information from all of the other words, generating a new representation per word informed by the entire context, represented by the filled balls. This step is then repeated multiple times in parallel for all words, successively generating new representations.</p> </blockquote> <p><img src="https://www.tensorflow.org/images/tutorials/transformer/apply_the_transformer_to_machine_translation.gif" alt="Applying the Transformer to machine translation"></p> <p>Figure 1: Applying the Transformer to machine translation. Source: <a href="https://ai.googleblog.com/2017/08/transformer-novel-neural-network.html">Google AI Blog</a>.</p> <p>That's a lot to digest, so the goal of this tutorial is to break it down into easy-to-understand parts. In this tutorial you will:</p> <ul> <li>Prepare the data.</li> <li>Implement necessary components: <ul> <li>Positional embeddings.</li> <li>Attention layers.</li> <li>The encoder and decoder.</li> </ul></li> <li>Build & train the Transformer.</li> <li>Generate translations.</li> <li>Export the model.</li> </ul> <p>To get the most out of this tutorial, it helps if you know about <a href="https://www.tensorflow.org/text/tutorials/text_generation">the basics of text generation</a> and attention mechanisms. 
</p> <p>A Transformer is a sequence-to-sequence encoder-decoder model similar to the model in the <a href="https://www.tensorflow.org/text/tutorials/nmt_with_attention">NMT with attention tutorial</a>. A single-layer Transformer takes a little more code to write, but is almost identical to that encoder-decoder RNN model. The only difference is that the RNN layers are replaced with self-attention layers. This tutorial builds a 4-layer Transformer which is larger and more powerful, but not fundamentally more complex.</p> <table> <tr> <th>The <a href=https://www.tensorflow.org/text/tutorials/nmt_with_attention>RNN+Attention model</a></th> <th>A 1-layer transformer</th> </tr> <tr> <td> <img width="411" src="https://www.tensorflow.org/images/tutorials/transformer/RNN+attention-words.png"> </td> <td> <img width="400" src="https://www.tensorflow.org/images/tutorials/transformer/Transformer-1layer-words.png"> </td> </tr> </table> <p>After training the model in this notebook, you will be able to input a Portuguese sentence and return the English translation.</p> <p><img src="https://www.tensorflow.org/images/tutorials/transformer/attention_map_portuguese.png" alt="Attention heatmap"></p> <p>Figure 2: Visualized attention weights that you can generate at the end of this tutorial.</p> <h2 id="why_transformers_are_significant" data-text="Why Transformers are significant" tabindex="-1">Why Transformers are significant</h2> <ul> <li>Transformers excel at modeling sequential data, such as natural language.</li> <li>Unlike <a href="https://www.tensorflow.org/text/tutorials/text_generation">recurrent neural networks (RNNs)</a>, Transformers are parallelizable. This makes them efficient on hardware like GPUs and TPUs. The main reason is that Transformers replaced recurrence with attention, and computations can happen simultaneously. 
Layer outputs can be computed in parallel, instead of a series like an RNN.</li> <li>Unlike <a href="https://www.tensorflow.org/guide/keras/rnn">RNNs</a> (such as <a href="https://arxiv.org/abs/1409.3215">seq2seq, 2014</a>) or <a href="https://www.tensorflow.org/tutorials/images/cnn">convolutional neural networks (CNNs)</a> (for example, <a href="https://arxiv.org/abs/1610.10099">ByteNet</a>), Transformers are able to capture distant or long-range contexts and dependencies in the data between distant positions in the input or output sequences. Thus, longer connections can be learned. Attention allows each location to have access to the entire input at each layer, while in RNNs and CNNs, the information needs to pass through many processing steps to move a long distance, which makes it harder to learn.</li> <li>Transformers make no assumptions about the temporal/spatial relationships across the data. This is ideal for processing a set of objects (for example, <a href="https://www.deepmind.com/blog/alphastar-mastering-the-real-time-strategy-game-starcraft-ii">StarCraft units</a>).</li> </ul> <p><img src="https://www.tensorflow.org/images/tutorials/transformer/encoder_self_attention_distribution.png" width="800" alt="Encoder self-attention distribution for the word it from the 5th to the 6th layer of a Transformer trained on English-to-French translation"></p> <p>Figure 3: The encoder self-attention distribution for the word “it” from the 5th to the 6th layer of a Transformer trained on English-to-French translation (one of eight attention heads). 
Source: <a href="https://ai.googleblog.com/2017/08/transformer-novel-neural-network.html">Google AI Blog</a>.</p> <h2 id="setup" data-text="Setup" tabindex="-1">Setup</h2> <p>Begin by installing <a href="https://tensorflow.org/datasets">TensorFlow Datasets</a> for loading the dataset and <a href="https://www.tensorflow.org/text">TensorFlow Text</a> for text preprocessing:</p> <pre class="prettyprint lang-bsh" translate="no" dir="ltr"> <code class='devsite-terminal' translate="no" dir="ltr"># Install the most recent version of TensorFlow to use the improved</code> <code class='devsite-terminal' translate="no" dir="ltr"># masking support for `tf.keras.layers.MultiHeadAttention`.</code> <code class='devsite-terminal' translate="no" dir="ltr">apt install --allow-change-held-packages libcudnn8=8.1.0.77-1+cuda11.2</code> <code class='devsite-terminal' translate="no" dir="ltr">pip uninstall -y -q tensorflow keras tensorflow-estimator tensorflow-text</code> <code class='devsite-terminal' translate="no" dir="ltr">pip install protobuf~=3.20.3</code> <code class='devsite-terminal' translate="no" dir="ltr">pip install -q tensorflow_datasets</code> <code class='devsite-terminal' translate="no" dir="ltr">pip install -q -U tensorflow-text tensorflow</code> </pre> <p>Import the necessary modules:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">import logging import time import numpy as np import matplotlib.pyplot as plt import tensorflow_datasets as tfds import tensorflow as tf import tensorflow_text </code></pre> <h2 id="data_handling" data-text="Data handling" tabindex="-1">Data handling</h2> <p>This section downloads the dataset and the subword tokenizer, from <a href="https://www.tensorflow.org/text/guide/subwords_tokenizer">this tutorial</a>, then wraps it all up in a <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><code translate="no" dir="ltr">tf.data.Dataset</code></a> for training.</p> <p><section class="expandable 
tfo-display-only-on-site"> <button type="button" class="button-red button expand-control">Toggle section</button></p> <h3 id="download_the_dataset" data-text="Download the dataset" tabindex="-1">Download the dataset</h3> <p>Use TensorFlow Datasets to load the <a href="https://www.tensorflow.org/datasets/catalog/ted_hrlr_translate#ted_hrlr_translatept_to_en">Portuguese-English translation dataset</a> from the TED Talks Open Translation Project. This dataset contains approximately 52,000 training, 1,200 validation and 1,800 test examples.</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">examples, metadata = tfds.load('ted_hrlr_translate/pt_to_en', with_info=True, as_supervised=True) train_examples, val_examples = examples['train'], examples['validation'] </code></pre> <p>The <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><code translate="no" dir="ltr">tf.data.Dataset</code></a> object returned by TensorFlow Datasets yields pairs of text examples:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">for pt_examples, en_examples in train_examples.batch(3).take(1): print('> Examples in Portuguese:') for pt in pt_examples.numpy(): print(pt.decode('utf-8')) print() print('> Examples in English:') for en in en_examples.numpy(): print(en.decode('utf-8')) </code></pre> <h3 id="set_up_the_tokenizer" data-text="Set up the tokenizer" tabindex="-1">Set up the tokenizer</h3> <p>Now that you have loaded the dataset, you need to tokenize the text, so that each element is represented as a <a href="https://developers.google.com/machine-learning/glossary#token">token</a> or token ID (a numeric representation).</p> <p>Tokenization is the process of breaking up text into "tokens". Depending on the tokenizer, these tokens can represent sentence-pieces, words, subwords, or characters. 
To learn more about tokenization, visit <a href="https://www.tensorflow.org/text/guide/tokenizers">this guide</a>.</p> <p>This tutorial uses the tokenizers built in the <a href="https://www.tensorflow.org/text/guide/subwords_tokenizer">subword tokenizer</a> tutorial. That tutorial optimizes two <a href="https://www.tensorflow.org/text/api_docs/python/text/BertTokenizer"><code translate="no" dir="ltr">text.BertTokenizer</code></a> objects (one for English, one for Portuguese) for <strong>this dataset</strong> and exports them in a TensorFlow <code translate="no" dir="ltr">saved_model</code> format.</p> <blockquote> <aside class="note"><strong>Note:</strong><span> This is different from the <a href="https://arxiv.org/pdf/1706.03762.pdf">original paper</a>, section 5.1, where they used a single byte-pair tokenizer for both the source and target with a vocabulary-size of 37000.</span></aside></blockquote> <p>Download, extract, and import the <code translate="no" dir="ltr">saved_model</code>:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">model_name = 'ted_hrlr_translate_pt_en_converter' tf.keras.utils.get_file( f'{model_name}.zip', f'https://storage.googleapis.com/download.tensorflow.org/models/{model_name}.zip', cache_dir='.', cache_subdir='', extract=True ) </code></pre><pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">tokenizers = tf.saved_model.load(model_name) </code></pre> <p>The <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model"><code translate="no" dir="ltr">tf.saved_model</code></a> contains two text tokenizers, one for English and one for Portuguese. Both have the same methods:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">[item for item in dir(tokenizers.en) if not item.startswith('_')] </code></pre> <p>The <code translate="no" dir="ltr">tokenize</code> method converts a batch of strings to a padded-batch of token IDs. 
This method splits punctuation, lowercases and unicode-normalizes the input before tokenizing. That standardization is not visible here because the input data is already standardized.</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">print('> This is a batch of strings:') for en in en_examples.numpy(): print(en.decode('utf-8')) </code></pre><pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">encoded = tokenizers.en.tokenize(en_examples) print('> This is a padded-batch of token IDs:') for row in encoded.to_list(): print(row) </code></pre> <p>The <code translate="no" dir="ltr">detokenize</code> method attempts to convert these token IDs back to human-readable text: </p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">round_trip = tokenizers.en.detokenize(encoded) print('> This is human-readable text:') for line in round_trip.numpy(): print(line.decode('utf-8')) </code></pre> <p>The lower level <code translate="no" dir="ltr">lookup</code> method converts from token-IDs to token text:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">print('> This is the text split into tokens:') tokens = tokenizers.en.lookup(encoded) tokens </code></pre> <p>The output demonstrates the "subword" aspect of the subword tokenization.</p> <p>For example, the word <code translate="no" dir="ltr">'searchability'</code> is decomposed into <code translate="no" dir="ltr">'search'</code> and <code translate="no" dir="ltr">'##ability'</code>, and the word <code translate="no" dir="ltr">'serendipity'</code> into <code translate="no" dir="ltr">'s'</code>, <code translate="no" dir="ltr">'##ere'</code>, <code translate="no" dir="ltr">'##nd'</code>, <code translate="no" dir="ltr">'##ip'</code> and <code translate="no" dir="ltr">'##ity'</code>.</p> <p>Note that the tokenized text includes <code translate="no" dir="ltr">'[START]'</code> and <code translate="no" 
dir="ltr">'[END]'</code> tokens.</p> <p>The distribution of tokens per example in the dataset is as follows:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">lengths = [] for pt_examples, en_examples in train_examples.batch(1024): pt_tokens = tokenizers.pt.tokenize(pt_examples) lengths.append(pt_tokens.row_lengths()) en_tokens = tokenizers.en.tokenize(en_examples) lengths.append(en_tokens.row_lengths()) print('.', end='', flush=True) </code></pre><pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">all_lengths = np.concatenate(lengths) plt.hist(all_lengths, np.linspace(0, 500, 101)) plt.ylim(plt.ylim()) max_length = max(all_lengths) plt.plot([max_length, max_length], plt.ylim()) plt.title(f'Maximum tokens per example: {max_length}'); </code></pre> <h3 id="set_up_a_data_pipeline_with_tfdata" data-text="Set up a data pipeline with tf.data" tabindex="-1">Set up a data pipeline with <a href="https://www.tensorflow.org/api_docs/python/tf/data"><code translate="no" dir="ltr">tf.data</code></a></h3> <p>The following function takes batches of text as input, and converts them to a format suitable for training. </p> <ol> <li>It tokenizes them into ragged batches.</li> <li>It trims each to be no longer than <code translate="no" dir="ltr">MAX_TOKENS</code>.</li> <li>It splits the target (English) tokens into inputs and labels. These are shifted by one step so that at each input location the <code translate="no" dir="ltr">label</code> is the id of the next token.</li> <li>It converts the <code translate="no" dir="ltr">RaggedTensor</code>s to padded dense <code translate="no" dir="ltr">Tensor</code>s.</li> <li>It returns an <code translate="no" dir="ltr">(inputs, labels)</code> pair.</li> </ol> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">MAX_TOKENS=128 def prepare_batch(pt, en): pt = tokenizers.pt.tokenize(pt) # Output is ragged. pt = pt[:, :MAX_TOKENS] # Trim to MAX_TOKENS. 
pt = pt.to_tensor() # Convert to 0-padded dense Tensor en = tokenizers.en.tokenize(en) en = en[:, :(MAX_TOKENS+1)] en_inputs = en[:, :-1].to_tensor() # Drop the [END] tokens en_labels = en[:, 1:].to_tensor() # Drop the [START] tokens return (pt, en_inputs), en_labels </code></pre> <p>The function below converts a dataset of text examples into data of batches for training. </p> <ol> <li>It tokenizes the text, and filters out the sequences that are too long. (The <code translate="no" dir="ltr">batch</code>/<code translate="no" dir="ltr">unbatch</code> is included because the tokenizer is much more efficient on large batches).</li> <li>The <code translate="no" dir="ltr">cache</code> method ensures that that work is only executed once.</li> <li>Then <code translate="no" dir="ltr">shuffle</code> and, <code translate="no" dir="ltr">dense_to_ragged_batch</code> randomize the order and assemble batches of examples. </li> <li>Finally <code translate="no" dir="ltr">prefetch</code> runs the dataset in parallel with the model to ensure that data is available when needed. See <a href="https://www.tensorflow.org/guide/data_performance.ipynb">Better performance with the <code translate="no" dir="ltr">tf.data</code></a> for details.</li> </ol> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">BUFFER_SIZE = 20000 BATCH_SIZE = 64 </code></pre><pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">def make_batches(ds): return ( ds .shuffle(BUFFER_SIZE) .batch(BATCH_SIZE) .map(prepare_batch, tf.data.AUTOTUNE) .prefetch(buffer_size=tf.data.AUTOTUNE)) </code></pre> <p></section></p> <h2 id="test_the_dataset" data-text="Test the Dataset" tabindex="-1">Test the Dataset</h2> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr"># Create training and validation set batches. 
train_batches = make_batches(train_examples) val_batches = make_batches(val_examples) </code></pre> <p>The resulting <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><code translate="no" dir="ltr">tf.data.Dataset</code></a> objects are set up for training with Keras. Keras <a href="https://www.tensorflow.org/api_docs/python/tf/keras/Model#fit"><code translate="no" dir="ltr">Model.fit</code></a> training expects <code translate="no" dir="ltr">(inputs, labels)</code> pairs. The <code translate="no" dir="ltr">inputs</code> are pairs of tokenized Portuguese and English sequences, <code translate="no" dir="ltr">(pt, en)</code>. The <code translate="no" dir="ltr">labels</code> are the same English sequences shifted by 1. This shift is so that at each location in the input <code translate="no" dir="ltr">en</code> sequence, the <code translate="no" dir="ltr">label</code> is the next token.</p> <table> <tr> <th>Inputs at the bottom, labels at the top.</th> </tr> <tr> <td> <img width="400" src="https://www.tensorflow.org/images/tutorials/transformer/Transformer-1layer-words.png"> </td> </tr> </table> <p>This is the same as the <a href="/text/tutorials/text_generation">text generation tutorial</a>, except here you have additional input "context" (the Portuguese sequence) that the model is "conditioned" on.</p> <p>This setup is called "teacher forcing" because regardless of the model's output at each timestep, it gets the true value as input for the next timestep. This is a simple and efficient way to train a text generation model. It's efficient because you don't need to run the model sequentially; the outputs at the different sequence locations can be computed in parallel.</p> <p>You might have expected the <code translate="no" dir="ltr">input, output</code> pairs to simply be the <code translate="no" dir="ltr">Portuguese, English</code> sequences. 
Given the Portuguese sequence, the model would try to generate the English sequence.</p> <p>It's possible to train a model that way. You'd need to write out the inference loop and pass the model's output back to the input. It's slower (time steps can't run in parallel), and a harder task to learn (the model can't get the end of a sentence right until it gets the beginning right), but it can give a more stable model because the model has to learn to correct its own errors during training.</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">for (pt, en), en_labels in train_batches.take(1): break print(pt.shape) print(en.shape) print(en_labels.shape) </code></pre> <p>The <code translate="no" dir="ltr">en</code> and <code translate="no" dir="ltr">en_labels</code> are the same, just shifted by 1:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">print(en[0][:10]) print(en_labels[0][:10]) </code></pre> <h2 id="define_the_components" data-text="Define the components" tabindex="-1">Define the components</h2> <p>There's a lot going on inside a Transformer. 
The important things to remember are:</p> <ol> <li>It follows the same general pattern as a standard sequence-to-sequence model with an encoder and a decoder.</li> <li>If you work through it step by step it will all make sense.</li> </ol> <table> <tr> <th colspan=1>The original Transformer diagram</th> <th colspan=1>A representation of a 4-layer Transformer</th> </tr> <tr> <td> <img width="400" src="https://www.tensorflow.org/images/tutorials/transformer/transformer.png"> </td> <td> <img width="307" src="https://www.tensorflow.org/images/tutorials/transformer/Transformer-4layer-compact.png"> </td> </tr> </table> <p>Each of the components in these two diagrams will be explained as you progress through the tutorial.</p> <h3 id="the_embedding_and_positional_encoding_layer" data-text="The embedding and positional encoding layer" tabindex="-1">The embedding and positional encoding layer</h3> <p>The inputs to both the encoder and decoder use the same embedding and positional encoding logic. </p> <table> <tr> <th colspan=1>The embedding and positional encoding layer</th> </tr> <tr> <td> <img src="https://www.tensorflow.org/images/tutorials/transformer/PositionalEmbedding.png"> </td> </tr> </table> <p>Given a sequence of tokens, both the input tokens (Portuguese) and target tokens (English) have to be converted to vectors using a <a href="https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding"><code translate="no" dir="ltr">tf.keras.layers.Embedding</code></a> layer.</p> <p>The attention layers used throughout the model see their input as a set of vectors, with no order. The model also doesn't contain any recurrent or convolutional layers. 
It needs some way to identify word order; otherwise it would see the input sequence as a <a href="https://developers.google.com/machine-learning/glossary#bag-of-words">bag of words</a>. For instance, <code translate="no" dir="ltr">how are you</code>, <code translate="no" dir="ltr">how you are</code>, <code translate="no" dir="ltr">you how are</code>, and so on, would be indistinguishable.</p> <p>A Transformer adds a "Positional Encoding" to the embedding vectors. It uses a set of sines and cosines at different frequencies (across the sequence). By definition nearby elements will have similar position encodings.</p> <p>The original paper uses the following formula for calculating the positional encoding:</p> <p>\[\Large{PE_{(pos, 2i)} = \sin(pos / 10000^{2i / d_{model} })} \]</p> <p>\[\Large{PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i / d_{model} })} \]</p> <aside class="note"><strong>Note:</strong><span> The code below implements it, but instead of interleaving the sines and cosines, the vectors of sines and cosines are simply concatenated. Permuting the channels like this is functionally equivalent, and just a little easier to implement and show in the plots below.</span></aside><pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">def positional_encoding(length, depth): depth = depth/2 positions = np.arange(length)[:, np.newaxis] # (seq, 1) depths = np.arange(depth)[np.newaxis, :]/depth # (1, depth) angle_rates = 1 / (10000**depths) # (1, depth) angle_rads = positions * angle_rates # (pos, depth) pos_encoding = np.concatenate( [np.sin(angle_rads), np.cos(angle_rads)], axis=-1) return tf.cast(pos_encoding, dtype=tf.float32) </code></pre> <p>The position encoding function is a stack of sines and cosines that vibrate at different frequencies depending on their location along the depth of the embedding vector. 
They vibrate across the position axis.</p> <!-- Toggle section --> <p><section class="expandable"> <button type="button" class="button-red button expand-control">Toggle code</button></p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">pos_encoding = positional_encoding(length=2048, depth=512) # Check the shape. print(pos_encoding.shape) # Plot the dimensions. plt.pcolormesh(pos_encoding.numpy().T, cmap='RdBu') plt.ylabel('Depth') plt.xlabel('Position') plt.colorbar() plt.show() </code></pre> <p></section></p> <p>By definition these vectors align well with nearby vectors along the position axis. Below, the position encoding vectors are normalized and the vector from position <code translate="no" dir="ltr">1000</code> is compared, by dot-product, to all the others:</p> <!-- Toggle section --> <p><section class="expandable"> <button type="button" class="button-red button expand-control">Toggle code</button></p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">pos_encoding/=tf.norm(pos_encoding, axis=1, keepdims=True) p = pos_encoding[1000] dots = tf.einsum('pd,d -> p', pos_encoding, p) plt.subplot(2,1,1) plt.plot(dots) plt.ylim([0,1]) plt.plot([950, 950, float('nan'), 1050, 1050], [0,1,float('nan'),0,1], color='k', label='Zoom') plt.legend() plt.subplot(2,1,2) plt.plot(dots) plt.xlim([950, 1050]) plt.ylim([0,1]) </code></pre> <p></section></p> <p>So use this to create a <code translate="no" dir="ltr">PositionalEmbedding</code> layer that looks up a token's embedding vector and adds the position vector:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">class PositionalEmbedding(tf.keras.layers.Layer): def __init__(self, vocab_size, d_model): super().__init__() self.d_model = d_model self.embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True) self.pos_encoding = positional_encoding(length=2048, depth=d_model) def compute_mask(self, *args, **kwargs): 
return self.embedding.compute_mask(*args, **kwargs) def call(self, x): length = tf.shape(x)[1] x = self.embedding(x) # This factor sets the relative scale of the embedding and positional_encoding. x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32)) x = x + self.pos_encoding[tf.newaxis, :length, :] return x </code></pre> <blockquote> <aside class="note"><strong>Note:</strong><span> The <a href="https://arxiv.org/pdf/1706.03762.pdf">original paper</a>, sections 3.4 and 5.1, uses a single tokenizer and weight matrix for both the source and target languages. This tutorial uses two separate tokenizers and weight matrices.</span></aside></blockquote> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">embed_pt = PositionalEmbedding(vocab_size=tokenizers.pt.get_vocab_size().numpy(), d_model=512) embed_en = PositionalEmbedding(vocab_size=tokenizers.en.get_vocab_size().numpy(), d_model=512) pt_emb = embed_pt(pt) en_emb = embed_en(en) </code></pre><pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">en_emb._keras_mask </code></pre> <h3 id="add_and_normalize" data-text="Add and normalize" tabindex="-1">Add and normalize</h3> <table> <tr> <th colspan=2>Add and normalize</th> </tr> <tr> <td> <img src="https://www.tensorflow.org/images/tutorials/transformer/Add+Norm.png"> </td> </tr> </table> <p>These "Add & Norm" blocks are scattered throughout the model. Each one joins a residual connection and runs the result through a <code translate="no" dir="ltr">LayerNormalization</code> layer.</p> <p>The easiest way to organize the code is around these residual blocks. The following sections will define custom layer classes for each. </p> <p>The residual "Add & Norm" blocks are included so that training is efficient. 
The residual connection provides a direct path for the gradient (and ensures that vectors are <strong>updated</strong> by the attention layers instead of <strong>replaced</strong>), while the normalization maintains a reasonable scale for the outputs.</p> <aside class="note"><strong>Note:</strong><span> The implementations below use the <code translate="no" dir="ltr">Add</code> layer to ensure that Keras masks are propagated (the <code translate="no" dir="ltr">+</code> operator does not).</span></aside> <h3 id="the_base_attention_layer" data-text="The base attention layer" tabindex="-1">The base attention layer</h3> <p>Attention layers are used throughout the model. These are all identical except for how the attention is configured. Each one contains a <a href="https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention"><code translate="no" dir="ltr">layers.MultiHeadAttention</code></a>, a <a href="https://www.tensorflow.org/api_docs/python/tf/keras/layers/LayerNormalization"><code translate="no" dir="ltr">layers.LayerNormalization</code></a> and a <a href="https://www.tensorflow.org/api_docs/python/tf/keras/layers/Add"><code translate="no" dir="ltr">layers.Add</code></a>. </p> <table> <tr> <th colspan=2>The base attention layer</th> </tr> <tr> <td> <img src="https://www.tensorflow.org/images/tutorials/transformer/BaseAttention.png"> </td> </tr> </table> <p>To implement these attention layers, start with a simple base class that just contains the component layers. Each use-case will be implemented as a subclass. 
It's a little more code to write this way, but it keeps the intention clear.</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">class BaseAttention(tf.keras.layers.Layer): def __init__(self, **kwargs): super().__init__() self.mha = tf.keras.layers.MultiHeadAttention(**kwargs) self.layernorm = tf.keras.layers.LayerNormalization() self.add = tf.keras.layers.Add() </code></pre> <h4 id="attention_refresher" data-text="Attention refresher" tabindex="-1">Attention refresher</h4> <p>Before you get into the specifics of each usage, here is a quick refresher on how attention works:</p> <table> <tr> <th colspan=1>The base attention layer</th> </tr> <tr> <td> <img width="430" src="https://www.tensorflow.org/images/tutorials/transformer/BaseAttention-new.png"> </td> </tr> </table> <p>There are two inputs:</p> <ol> <li>The query sequence; the sequence being processed; the sequence doing the attending (bottom).</li> <li>The context sequence; the sequence being attended to (left).</li> </ol> <p>The output has the same shape as the query-sequence.</p> <p>The common comparison is that this operation is like a dictionary lookup. 
A <strong>fuzzy</strong>, <strong>differentiable</strong>, <strong>vectorized</strong> dictionary lookup.</p> <p>Here's a regular Python dictionary, with 3 keys and 3 values being passed a single query.</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">d = {'color': 'blue', 'age': 22, 'type': 'pickup'} result = d['color'] </code></pre> <ul> <li>The <code translate="no" dir="ltr">query</code> is what you're trying to find.</li> <li>The <code translate="no" dir="ltr">key</code> is what sort of information the dictionary has.</li> <li>The <code translate="no" dir="ltr">value</code> is that information.</li> </ul> <p>When you look up a <code translate="no" dir="ltr">query</code> in a regular dictionary, the dictionary finds the matching <code translate="no" dir="ltr">key</code>, and returns its associated <code translate="no" dir="ltr">value</code>. The <code translate="no" dir="ltr">query</code> either has a matching <code translate="no" dir="ltr">key</code> or it doesn't. You can imagine a <strong>fuzzy</strong> dictionary where the keys don't have to match perfectly. If you looked up <code translate="no" dir="ltr">d["species"]</code> in the dictionary above, maybe you'd want it to return <code translate="no" dir="ltr">"pickup"</code> since that's the best match for the query.</p> <p>An attention layer does a fuzzy lookup like this, but it's not just looking for the best key. It combines the <code translate="no" dir="ltr">values</code> based on how well the <code translate="no" dir="ltr">query</code> matches each <code translate="no" dir="ltr">key</code>.</p> <p>How does that work? In an attention layer the <code translate="no" dir="ltr">query</code>, <code translate="no" dir="ltr">key</code>, and <code translate="no" dir="ltr">value</code> are each vectors. 
Instead of doing a hash lookup the attention layer combines the <code translate="no" dir="ltr">query</code> and <code translate="no" dir="ltr">key</code> vectors to determine how well they match, the "attention score". The layer returns the average across all the <code translate="no" dir="ltr">values</code>, weighted by the "attention scores".</p> <p>Each location in the query-sequence provides a <code translate="no" dir="ltr">query</code> vector. The context sequence acts as the dictionary. Each location in the context sequence provides a <code translate="no" dir="ltr">key</code> and <code translate="no" dir="ltr">value</code> vector. The input vectors are not used directly; the <a href="https://www.tensorflow.org/api_docs/python/tf/keras/layers/MultiHeadAttention"><code translate="no" dir="ltr">layers.MultiHeadAttention</code></a> layer includes <a href="https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense"><code translate="no" dir="ltr">layers.Dense</code></a> layers to project the input vectors before using them.</p> <h3 id="the_cross_attention_layer" data-text="The cross attention layer" tabindex="-1">The cross attention layer</h3> <p>At the literal center of the Transformer is the cross-attention layer. This layer connects the encoder and decoder. 
This layer is the most straight-forward use of attention in the model, it performs the same task as the attention block in the <a href="https://www.tensorflow.org/text/tutorials/nmt_with_attention">NMT with attention tutorial</a>.</p> <table> <tr> <th colspan=1>The cross attention layer</th> <tr> <tr> <td> <img src="https://www.tensorflow.org/images/tutorials/transformer/CrossAttention.png"> </td> </tr> </table> <p>To implement this you pass the target sequence <code translate="no" dir="ltr">x</code> as the <code translate="no" dir="ltr">query</code> and the <code translate="no" dir="ltr">context</code> sequence as the <code translate="no" dir="ltr">key/value</code> when calling the <code translate="no" dir="ltr">mha</code> layer:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">class CrossAttention(BaseAttention): def call(self, x, context): attn_output, attn_scores = self.mha( query=x, key=context, value=context, return_attention_scores=True) # Cache the attention scores for plotting later. self.last_attn_scores = attn_scores x = self.add([x, attn_output]) x = self.layernorm(x) return x </code></pre> <p>The caricature below shows how information flows through this layer. The columns represent the weighted sum over the context sequence.</p> <p>For simplicity the residual connections are not shown.</p> <table> <tr> <th>The cross attention layer</th> </tr> <tr> <td> <img width="430" src="https://www.tensorflow.org/images/tutorials/transformer/CrossAttention-new-full.png"> </td> </tr> </table> <p>The output length is the length of the <code translate="no" dir="ltr">query</code> sequence, and not the length of the context <code translate="no" dir="ltr">key/value</code> sequence.</p> <p>The diagram is further simplified, below. There's no need to draw the entire "Attention weights" matrix. 
The point is that each <code translate="no" dir="ltr">query</code> location can see all the <code translate="no" dir="ltr">key/value</code> pairs in the context, but no information is exchanged between the queries.</p> <table> <tr> <th>Each query sees the whole context.</th> </tr> <tr> <td> <img width="430" src="https://www.tensorflow.org/images/tutorials/transformer/CrossAttention-new.png"> </td> </tr> </table> <p>Test run it on sample inputs:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">sample_ca = CrossAttention(num_heads=2, key_dim=512) print(pt_emb.shape) print(en_emb.shape) print(sample_ca(en_emb, pt_emb).shape) </code></pre> <h3 id="the_global_self-attention_layer" data-text="The global self-attention layer" tabindex="-1">The global self-attention layer</h3> <p>This layer is responsible for processing the context sequence, and propagating information along its length:</p> <table> <tr> <th colspan=1>The global self-attention layer</th> <tr> <tr> <td> <img src="https://www.tensorflow.org/images/tutorials/transformer/SelfAttention.png"> </td> </tr> </table> <p>Since the context sequence is fixed while the translation is being generated, information is allowed to flow in both directions. </p> <p>Before Transformers and self-attention, models commonly used RNNs or CNNs to do this task:</p> <table> <tr> <th colspan=1>Bidirectional RNNs and CNNs</th> </tr> <tr> <td> <img width="500" src="https://www.tensorflow.org/images/tutorials/transformer/RNN-bidirectional.png"> </td> </tr> <tr> <td> <img width="500" src="https://www.tensorflow.org/images/tutorials/transformer/CNN.png"> </td> </tr> </table> <p>RNNs and CNNs have their limitations.</p> <ul> <li>The RNN allows information to flow all the way across the sequence, but it passes through many processing steps to get there (limiting gradient flow). 
These RNN steps have to be run sequentially and so the RNN is less able to take advantage of modern parallel devices.</li> <li>In the CNN each location can be processed in parallel, but it only provides a limited receptive field. The receptive field only grows linearly with the number of CNN layers. You need to stack a number of Convolution layers to transmit information across the sequence (<a href="https://arxiv.org/abs/1609.03499">Wavenet</a> reduces this problem by using dilated convolutions).</li> </ul> <p>The global self-attention layer on the other hand lets every sequence element directly access every other sequence element, with only a few operations, and all the outputs can be computed in parallel. </p> <p>To implement this layer you just need to pass the sequence, <code translate="no" dir="ltr">x</code>, as the <code translate="no" dir="ltr">query</code>, <code translate="no" dir="ltr">key</code>, and <code translate="no" dir="ltr">value</code> arguments to the <code translate="no" dir="ltr">mha</code> layer: </p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">class GlobalSelfAttention(BaseAttention): def call(self, x): attn_output = self.mha( query=x, value=x, key=x) x = self.add([x, attn_output]) x = self.layernorm(x) return x </code></pre><pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">sample_gsa = GlobalSelfAttention(num_heads=2, key_dim=512) print(pt_emb.shape) print(sample_gsa(pt_emb).shape) </code></pre> <p>Sticking with the same style as before you could draw it like this:</p> <table> <tr> <th colspan=1>The global self-attention layer</th> </tr> <tr> <td> <img width="330" src="https://www.tensorflow.org/images/tutorials/transformer/SelfAttention-new-full.png"> </td> </tr> </table> <p>Again, the residual connections are omitted for clarity.</p> <p>It's more compact, and just as accurate to draw it like this:</p> <table> <tr> <th colspan=1>The global self-attention layer</th> </tr> <tr> <td> <img width="500" 
src="https://www.tensorflow.org/images/tutorials/transformer/SelfAttention-new.png"> </td> </tr> </table> <h3 id="the_causal_self-attention_layer" data-text="The causal self-attention layer" tabindex="-1">The causal self-attention layer</h3> <p>This layer does a similar job as the global self-attention layer, for the output sequence:</p> <table> <tr> <th colspan=1>The causal self-attention layer</th> <tr> <tr> <td> <img src="https://www.tensorflow.org/images/tutorials/transformer/CausalSelfAttention.png"> </td> </tr> </table> <p>This needs to be handled differently from the encoder's global self-attention layer. </p> <p>Like the <a href="https://www.tensorflow.org/text/tutorials/text_generation">text generation tutorial</a>, and the <a href="https://www.tensorflow.org/text/tutorials/nmt_with_attention">NMT with attention</a> tutorial, Transformers are an "autoregressive" model: They generate the text one token at a time and feed that output back to the input. To make this <em>efficient</em>, these models ensure that the output for each sequence element only depends on the previous sequence elements; the models are "causal".</p> <p>A single-direction RNN is causal by definition. 
To make a causal convolution you just need to pad the input and shift the output so that it aligns correctly (use <code translate="no" dir="ltr">layers.Conv1D(padding='causal')</code>) .</p> <table> <tr> <th colspan=1>Causal RNNs and CNNs</th> </tr> <tr> <td> <img width="500" src="https://www.tensorflow.org/images/tutorials/transformer/RNN.png"> </td> </tr> <tr> <td> <img width="500" src="https://www.tensorflow.org/images/tutorials/transformer/CNN-causal.png"> </td> </tr> </table> <p>A causal model is efficient in two ways: </p> <ol> <li>In training, it lets you compute loss for every location in the output sequence while executing the model just once.</li> <li>During inference, for each new token generated you only need to calculate its outputs, the outputs for the previous sequence elements can be reused. <ul> <li>For an RNN you just need the RNN-state to account for previous computations (pass <code translate="no" dir="ltr">return_state=True</code> to the RNN layer's constructor).</li> <li>For a CNN you would need to follow the approach of <a href="https://arxiv.org/abs/1611.09482">Fast Wavenet</a></li> </ul></li> </ol> <p>To build a causal self-attention layer, you need to use an appropriate mask when computing the attention scores and summing the attention <code translate="no" dir="ltr">value</code>s.</p> <p>This is taken care of automatically if you pass <code translate="no" dir="ltr">use_causal_mask = True</code> to the <code translate="no" dir="ltr">MultiHeadAttention</code> layer when you call it:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">class CausalSelfAttention(BaseAttention): def call(self, x): attn_output = self.mha( query=x, value=x, key=x, use_causal_mask = True) x = self.add([x, attn_output]) x = self.layernorm(x) return x </code></pre> <p>The causal mask ensures that each location only has access to the locations that come before it: </p> <table> <tr> <th colspan=1>The causal self-attention layer</th> 
</tr> <tr> <td> <img width="330" src="https://www.tensorflow.org/images/tutorials/transformer/CausalSelfAttention-new-full.png"> </td> </tr> </table> <p>Again, the residual connections are omitted for simplicity.</p> <p>The more compact representation of this layer would be:</p> <table> <tr> <th colspan=1>The causal self-attention layer</th> </tr> <tr> <td> <img width="430" src="https://www.tensorflow.org/images/tutorials/transformer/CausalSelfAttention-new.png"> </td> </tr> </table> <p>Test out the layer:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">sample_csa = CausalSelfAttention(num_heads=2, key_dim=512) print(en_emb.shape) print(sample_csa(en_emb).shape) </code></pre> <p>The output for early sequence elements doesn't depend on later elements, so it shouldn't matter if you trim elements before or after applying the layer:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">out1 = sample_csa(embed_en(en[:, :3])) out2 = sample_csa(embed_en(en))[:, :3] tf.reduce_max(abs(out1 - out2)).numpy() </code></pre><aside class="note"><strong>Note:</strong><span> When using Keras masks, the output values at invalid locations are not well defined. So the above may not hold for masked regions. </span></aside> <h3 id="the_feed_forward_network" data-text="The feed forward network" tabindex="-1">The feed forward network</h3> <p>The transformer also includes this point-wise feed-forward network in both the encoder and decoder:</p> <table> <tr> <th colspan=1>The feed forward network</th> </tr> <tr> <td> <img src="https://www.tensorflow.org/images/tutorials/transformer/FeedForward.png"> </td> </tr> </table> <p>The network consists of two linear layers (<a href="https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dense"><code translate="no" dir="ltr">tf.keras.layers.Dense</code></a>) with a ReLU activation in-between, and a dropout layer. 
As with the attention layers the code here also includes the residual connection and normalization:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">class FeedForward(tf.keras.layers.Layer): def __init__(self, d_model, dff, dropout_rate=0.1): super().__init__() self.seq = tf.keras.Sequential([ tf.keras.layers.Dense(dff, activation='relu'), tf.keras.layers.Dense(d_model), tf.keras.layers.Dropout(dropout_rate) ]) self.add = tf.keras.layers.Add() self.layer_norm = tf.keras.layers.LayerNormalization() def call(self, x): x = self.add([x, self.seq(x)]) x = self.layer_norm(x) return x </code></pre> <p>Test the layer, the output is the same shape as the input:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">sample_ffn = FeedForward(512, 2048) print(en_emb.shape) print(sample_ffn(en_emb).shape) </code></pre> <h3 id="the_encoder_layer" data-text="The encoder layer" tabindex="-1">The encoder layer</h3> <p>The encoder contains a stack of <code translate="no" dir="ltr">N</code> encoder layers. 
Where each <code translate="no" dir="ltr">EncoderLayer</code> contains a <code translate="no" dir="ltr">GlobalSelfAttention</code> and <code translate="no" dir="ltr">FeedForward</code> layer:</p> <table> <tr> <th colspan=1>The encoder layer</th> <tr> <tr> <td> <img src="https://www.tensorflow.org/images/tutorials/transformer/EncoderLayer.png"> </td> </tr> </table> <p>Here is the definition of the <code translate="no" dir="ltr">EncoderLayer</code>:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">class EncoderLayer(tf.keras.layers.Layer): def __init__(self,*, d_model, num_heads, dff, dropout_rate=0.1): super().__init__() self.self_attention = GlobalSelfAttention( num_heads=num_heads, key_dim=d_model, dropout=dropout_rate) self.ffn = FeedForward(d_model, dff) def call(self, x): x = self.self_attention(x) x = self.ffn(x) return x </code></pre> <p>And a quick test, the output will have the same shape as the input:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">sample_encoder_layer = EncoderLayer(d_model=512, num_heads=8, dff=2048) print(pt_emb.shape) print(sample_encoder_layer(pt_emb).shape) </code></pre> <h3 id="the_encoder" data-text="The encoder" tabindex="-1">The encoder</h3> <p>Next build the encoder.</p> <table> <tr> <th colspan=1>The encoder</th> <tr> <tr> <td> <img src="https://www.tensorflow.org/images/tutorials/transformer/Encoder.png"> </td> </tr> </table> <p>The encoder consists of:</p> <ul> <li>A <code translate="no" dir="ltr">PositionalEmbedding</code> layer at the input.</li> <li>A stack of <code translate="no" dir="ltr">EncoderLayer</code> layers.</li> </ul> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">class Encoder(tf.keras.layers.Layer): def __init__(self, *, num_layers, d_model, num_heads, dff, vocab_size, dropout_rate=0.1): super().__init__() self.d_model = d_model self.num_layers = num_layers self.pos_embedding = PositionalEmbedding( 
vocab_size=vocab_size, d_model=d_model) self.enc_layers = [ EncoderLayer(d_model=d_model, num_heads=num_heads, dff=dff, dropout_rate=dropout_rate) for _ in range(num_layers)] self.dropout = tf.keras.layers.Dropout(dropout_rate) def call(self, x): # `x` is token-IDs shape: (batch, seq_len) x = self.pos_embedding(x) # Shape `(batch_size, seq_len, d_model)`. # Add dropout. x = self.dropout(x) for i in range(self.num_layers): x = self.enc_layers[i](x) return x # Shape `(batch_size, seq_len, d_model)`. </code></pre> <p>Test the encoder:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr"># Instantiate the encoder. sample_encoder = Encoder(num_layers=4, d_model=512, num_heads=8, dff=2048, vocab_size=8500) sample_encoder_output = sample_encoder(pt, training=False) # Print the shape. print(pt.shape) print(sample_encoder_output.shape) # Shape `(batch_size, input_seq_len, d_model)`. </code></pre> <h3 id="the_decoder_layer" data-text="The decoder layer" tabindex="-1">The decoder layer</h3> <p>The decoder's stack is slightly more complex, with each <code translate="no" dir="ltr">DecoderLayer</code> containing a <code translate="no" dir="ltr">CausalSelfAttention</code>, a <code translate="no" dir="ltr">CrossAttention</code>, and a <code translate="no" dir="ltr">FeedForward</code> layer: </p> <table> <tr> <th colspan=1>The decoder layer</th> <tr> <tr> <td> <img src="https://www.tensorflow.org/images/tutorials/transformer/DecoderLayer.png"> </td> </tr> </table> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">class DecoderLayer(tf.keras.layers.Layer): def __init__(self, *, d_model, num_heads, dff, dropout_rate=0.1): super(DecoderLayer, self).__init__() self.causal_self_attention = CausalSelfAttention( num_heads=num_heads, key_dim=d_model, dropout=dropout_rate) self.cross_attention = CrossAttention( num_heads=num_heads, key_dim=d_model, dropout=dropout_rate) self.ffn = FeedForward(d_model, dff) def call(self, x, 
context): x = self.causal_self_attention(x=x) x = self.cross_attention(x=x, context=context) # Cache the last attention scores for plotting later self.last_attn_scores = self.cross_attention.last_attn_scores x = self.ffn(x) # Shape `(batch_size, seq_len, d_model)`. return x </code></pre> <p>Test the decoder layer:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">sample_decoder_layer = DecoderLayer(d_model=512, num_heads=8, dff=2048) sample_decoder_layer_output = sample_decoder_layer( x=en_emb, context=pt_emb) print(en_emb.shape) print(pt_emb.shape) print(sample_decoder_layer_output.shape) # `(batch_size, seq_len, d_model)` </code></pre> <h3 id="the_decoder" data-text="The decoder" tabindex="-1">The decoder</h3> <p>Similar to the <code translate="no" dir="ltr">Encoder</code>, the <code translate="no" dir="ltr">Decoder</code> consists of a <code translate="no" dir="ltr">PositionalEmbedding</code>, and a stack of <code translate="no" dir="ltr">DecoderLayer</code>s:</p> <table> <tr> <th colspan=1>The decoder</th> </tr> <tr> <td> <img src="https://www.tensorflow.org/images/tutorials/transformer/Decoder.png"> </td> </tr> </table> <p>Define the decoder by extending <a href="https://www.tensorflow.org/api_docs/python/tf/keras/Layer"><code translate="no" dir="ltr">tf.keras.layers.Layer</code></a>:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">class Decoder(tf.keras.layers.Layer): def __init__(self, *, num_layers, d_model, num_heads, dff, vocab_size, dropout_rate=0.1): super(Decoder, self).__init__() self.d_model = d_model self.num_layers = num_layers self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size, d_model=d_model) self.dropout = tf.keras.layers.Dropout(dropout_rate) self.dec_layers = [ DecoderLayer(d_model=d_model, num_heads=num_heads, dff=dff, dropout_rate=dropout_rate) for _ in range(num_layers)] self.last_attn_scores = None def call(self, x, context): 
# `x` is token-IDs shape (batch, target_seq_len) x = self.pos_embedding(x) # (batch_size, target_seq_len, d_model) x = self.dropout(x) for i in range(self.num_layers): x = self.dec_layers[i](x, context) self.last_attn_scores = self.dec_layers[-1].last_attn_scores # The shape of x is (batch_size, target_seq_len, d_model). return x </code></pre> <p>Test the decoder:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr"># Instantiate the decoder. sample_decoder = Decoder(num_layers=4, d_model=512, num_heads=8, dff=2048, vocab_size=8000) output = sample_decoder( x=en, context=pt_emb) # Print the shapes. print(en.shape) print(pt_emb.shape) print(output.shape) </code></pre><pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">sample_decoder.last_attn_scores.shape # (batch, heads, target_seq, input_seq) </code></pre> <p>Having created the Transformer encoder and decoder, it's time to build the Transformer model and train it.</p> <h2 id="the_transformer" data-text="The Transformer" tabindex="-1">The Transformer</h2> <p>You now have <code translate="no" dir="ltr">Encoder</code> and <code translate="no" dir="ltr">Decoder</code>. To complete the <code translate="no" dir="ltr">Transformer</code> model, you need to put them together and add a final linear (<code translate="no" dir="ltr">Dense</code>) layer which converts the resulting vector at each location into output token probabilities. 
</p> <p>The output of the decoder is the input to this final linear layer.</p> <table> <tr> <th colspan=1>The transformer</th> <tr> <tr> <td> <img src="https://www.tensorflow.org/images/tutorials/transformer/transformer.png"> </td> </tr> </table> <p>A <code translate="no" dir="ltr">Transformer</code> with one layer in both the <code translate="no" dir="ltr">Encoder</code> and <code translate="no" dir="ltr">Decoder</code> looks almost exactly like the model from the <a href="https://www.tensorflow.org/text/tutorials/nmt_with_attention">RNN+attention tutorial</a>. A multi-layer Transformer has more layers, but is fundamentally doing the same thing.</p> <table> <tr> <th colspan=1>A 1-layer transformer</th> <th colspan=1>A 4-layer transformer</th> </tr> <tr> <td> <img width="400" src="https://www.tensorflow.org/images/tutorials/transformer/Transformer-1layer-compact.png"> </td> <td rowspan=3> <img width="330" src="https://www.tensorflow.org/images/tutorials/transformer/Transformer-4layer-compact.png"> </td> </tr> <tr> <th colspan=1>The RNN+Attention model</th> </tr> <tr> <td> <img width="400" src="https://www.tensorflow.org/images/tutorials/transformer/RNN+attention-compact.png"> </td> </tr> </table> <p>Create the <code translate="no" dir="ltr">Transformer</code> by extending <a href="https://www.tensorflow.org/api_docs/python/tf/keras/Model"><code translate="no" dir="ltr">tf.keras.Model</code></a>:</p> <blockquote> <aside class="note"><strong>Note:</strong><span> The <a href="https://arxiv.org/pdf/1706.03762.pdf">original paper</a>, section 3.4, shares the weight matrix between the embedding layer and the final linear layer. 
To keep things simple, this tutorial uses two separate weight matrices.</span></aside></blockquote> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">class Transformer(tf.keras.Model): def __init__(self, *, num_layers, d_model, num_heads, dff, input_vocab_size, target_vocab_size, dropout_rate=0.1): super().__init__() self.encoder = Encoder(num_layers=num_layers, d_model=d_model, num_heads=num_heads, dff=dff, vocab_size=input_vocab_size, dropout_rate=dropout_rate) self.decoder = Decoder(num_layers=num_layers, d_model=d_model, num_heads=num_heads, dff=dff, vocab_size=target_vocab_size, dropout_rate=dropout_rate) self.final_layer = tf.keras.layers.Dense(target_vocab_size) def call(self, inputs): # To use a Keras model with `.fit` you must pass all your inputs in the # first argument. context, x = inputs context = self.encoder(context) # (batch_size, context_len, d_model) x = self.decoder(x, context) # (batch_size, target_len, d_model) # Final linear layer output. logits = self.final_layer(x) # (batch_size, target_len, target_vocab_size) try: # Drop the keras mask, so it doesn't scale the losses/metrics. # b/250038731 del logits._keras_mask except AttributeError: pass # Return the final output and the attention weights. 
return logits </code></pre> <h3 id="hyperparameters" data-text="Hyperparameters" tabindex="-1">Hyperparameters</h3> <p>To keep this example small and relatively fast, the number of layers (<code translate="no" dir="ltr">num_layers</code>), the dimensionality of the embeddings (<code translate="no" dir="ltr">d_model</code>), and the internal dimensionality of the <code translate="no" dir="ltr">FeedForward</code> layer (<code translate="no" dir="ltr">dff</code>) have been reduced.</p> <p>The base model described in the original Transformer paper used <code translate="no" dir="ltr">num_layers=6</code>, <code translate="no" dir="ltr">d_model=512</code>, and <code translate="no" dir="ltr">dff=2048</code>.</p> <p>The number of self-attention heads remains the same (<code translate="no" dir="ltr">num_heads=8</code>).</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">num_layers = 4 d_model = 128 dff = 512 num_heads = 8 dropout_rate = 0.1 </code></pre> <h3 id="try_it_out" data-text="Try it out" tabindex="-1">Try it out</h3> <p>Instantiate the <code translate="no" dir="ltr">Transformer</code> model:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">transformer = Transformer( num_layers=num_layers, d_model=d_model, num_heads=num_heads, dff=dff, input_vocab_size=tokenizers.pt.get_vocab_size().numpy(), target_vocab_size=tokenizers.en.get_vocab_size().numpy(), dropout_rate=dropout_rate) </code></pre> <p>Test it:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">output = transformer((pt, en)) print(en.shape) print(pt.shape) print(output.shape) </code></pre><pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">attn_scores = transformer.decoder.dec_layers[-1].last_attn_scores print(attn_scores.shape) # (batch, heads, target_seq, input_seq) </code></pre> <p>Print the summary of the model:</p> <pre class="prettyprint" translate="no" dir="ltr"><code 
translate="no" dir="ltr">transformer.summary() </code></pre> <h2 id="training" data-text="Training" tabindex="-1">Training</h2> <p>It's time to prepare the model and start training it.</p> <h3 id="set_up_the_optimizer" data-text="Set up the optimizer" tabindex="-1">Set up the optimizer</h3> <p>Use the Adam optimizer with a custom learning rate scheduler according to the formula in the original Transformer <a href="https://arxiv.org/abs/1706.03762">paper</a>.</p> <p>\[\Large{lrate = d_{model}^{-0.5} * \min(step{\_}num^{-0.5}, step{\_}num \cdot warmup{\_}steps^{-1.5})}\]</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule): def __init__(self, d_model, warmup_steps=4000): super().__init__() self.d_model = d_model self.d_model = tf.cast(self.d_model, tf.float32) self.warmup_steps = warmup_steps def __call__(self, step): step = tf.cast(step, dtype=tf.float32) arg1 = tf.math.rsqrt(step) arg2 = step * (self.warmup_steps ** -1.5) return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2) </code></pre> <p>Instantiate the optimizer (in this example it's <a href="https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam"><code translate="no" dir="ltr">tf.keras.optimizers.Adam</code></a>):</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">learning_rate = CustomSchedule(d_model) optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9) </code></pre> <p>Test the custom learning rate scheduler:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">plt.plot(learning_rate(tf.range(40000, dtype=tf.float32))) plt.ylabel('Learning Rate') plt.xlabel('Train Step') </code></pre> <h3 id="set_up_the_loss_and_metrics" data-text="Set up the loss and metrics" tabindex="-1">Set up the loss and metrics</h3> <p>Since the target sequences are padded, it is important to apply a 
padding mask when calculating the loss. Use the cross-entropy loss function (<a href="https://www.tensorflow.org/api_docs/python/tf/keras/losses/SparseCategoricalCrossentropy"><code translate="no" dir="ltr">tf.keras.losses.SparseCategoricalCrossentropy</code></a>):</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">def masked_loss(label, pred): mask = label != 0 loss_object = tf.keras.losses.SparseCategoricalCrossentropy( from_logits=True, reduction='none') loss = loss_object(label, pred) mask = tf.cast(mask, dtype=loss.dtype) loss *= mask loss = tf.reduce_sum(loss)/tf.reduce_sum(mask) return loss def masked_accuracy(label, pred): pred = tf.argmax(pred, axis=2) label = tf.cast(label, pred.dtype) match = label == pred mask = label != 0 match = match & mask match = tf.cast(match, dtype=tf.float32) mask = tf.cast(mask, dtype=tf.float32) return tf.reduce_sum(match)/tf.reduce_sum(mask) </code></pre> <h3 id="train_the_model" data-text="Train the model" tabindex="-1">Train the model</h3> <p>With all the components ready, configure the training procedure using <code translate="no" dir="ltr">model.compile</code>, and then run it with <code translate="no" dir="ltr">model.fit</code>:</p> <aside class="note"><strong>Note:</strong><span> This takes about an hour to train in Colab.</span></aside><pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">transformer.compile( loss=masked_loss, optimizer=optimizer, metrics=[masked_accuracy]) </code></pre><pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">transformer.fit(train_batches, epochs=20, validation_data=val_batches) </code></pre> <h2 id="run_inference" data-text="Run inference" tabindex="-1">Run inference</h2> <p>You can now test the model by performing a translation. The following steps are used for inference:</p> <ul> <li>Encode the input sentence using the Portuguese tokenizer (<code translate="no" dir="ltr">tokenizers.pt</code>). 
This is the encoder input.</li> <li>The decoder input is initialized to the <code translate="no" dir="ltr">[START]</code> token.</li> <li>Calculate the padding masks and the look ahead masks.</li> <li>The <code translate="no" dir="ltr">decoder</code> then outputs the predictions by looking at the <code translate="no" dir="ltr">encoder output</code> and its own output (self-attention).</li> <li>Concatenate the predicted token to the decoder input and pass it to the decoder.</li> <li>In this approach, the decoder predicts the next token based on the previous tokens it predicted.</li> </ul> <aside class="note"><strong>Note:</strong><span> The model is optimized for <em>efficient training</em> and makes a next-token prediction for each token in the output simultaneously. This is redundant during inference, and only the last prediction is used. This model can be made more efficient for inference if you only calculate the last prediction when running in inference mode (<code translate="no" dir="ltr">training=False</code>).</span></aside> <p>Define the <code translate="no" dir="ltr">Translator</code> class by subclassing <a href="https://www.tensorflow.org/api_docs/python/tf/Module"><code translate="no" dir="ltr">tf.Module</code></a>:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">class Translator(tf.Module): def __init__(self, tokenizers, transformer): self.tokenizers = tokenizers self.transformer = transformer def __call__(self, sentence, max_length=MAX_TOKENS): # The input sentence is Portuguese, hence adding the `[START]` and `[END]` tokens. assert isinstance(sentence, tf.Tensor) if len(sentence.shape) == 0: sentence = sentence[tf.newaxis] sentence = self.tokenizers.pt.tokenize(sentence).to_tensor() encoder_input = sentence # As the output language is English, initialize the output with the # English `[START]` token. 
start_end = self.tokenizers.en.tokenize([''])[0] start = start_end[0][tf.newaxis] end = start_end[1][tf.newaxis] # `tf.TensorArray` is required here (instead of a Python list), so that the # dynamic-loop can be traced by `tf.function`. output_array = tf.TensorArray(dtype=tf.int64, size=0, dynamic_size=True) output_array = output_array.write(0, start) for i in tf.range(max_length): output = tf.transpose(output_array.stack()) predictions = self.transformer([encoder_input, output], training=False) # Select the last token from the `seq_len` dimension. predictions = predictions[:, -1:, :] # Shape `(batch_size, 1, vocab_size)`. predicted_id = tf.argmax(predictions, axis=-1) # Concatenate the `predicted_id` to the output which is given to the # decoder as its input. output_array = output_array.write(i+1, predicted_id[0]) if predicted_id == end: break output = tf.transpose(output_array.stack()) # The output shape is `(1, tokens)`. text = tokenizers.en.detokenize(output)[0] # Shape: `()`. tokens = tokenizers.en.lookup(output)[0] # `tf.function` prevents us from using the attention_weights that were # calculated on the last iteration of the loop. # So, recalculate them outside the loop. self.transformer([encoder_input, output[:,:-1]], training=False) attention_weights = self.transformer.decoder.last_attn_scores return text, tokens, attention_weights </code></pre><aside class="note"><strong>Note:</strong><span> This function uses an unrolled loop, not a dynamic loop. It generates <code translate="no" dir="ltr">MAX_TOKENS</code> on every call. 
Refer to the <a href="/text/tutorials/nmt_with_attention">NMT with attention</a> tutorial for an example implementation with a dynamic loop, which can be much more efficient.</span></aside> <p>Create an instance of this <code translate="no" dir="ltr">Translator</code> class, and try it out a few times:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">translator = Translator(tokenizers, transformer) </code></pre><pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">def print_translation(sentence, tokens, ground_truth): print(f'{"Input:":15s}: {sentence}') print(f'{"Prediction":15s}: {tokens.numpy().decode("utf-8")}') print(f'{"Ground truth":15s}: {ground_truth}') </code></pre> <p>Example 1:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">sentence = 'este é um problema que temos que resolver.' ground_truth = 'this is a problem we have to solve .' translated_text, translated_tokens, attention_weights = translator( tf.constant(sentence)) print_translation(sentence, translated_text, ground_truth) </code></pre> <p>Example 2:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">sentence = 'os meus vizinhos ouviram sobre esta ideia.' ground_truth = 'and my neighboring homes heard about this idea .' translated_text, translated_tokens, attention_weights = translator( tf.constant(sentence)) print_translation(sentence, translated_text, ground_truth) </code></pre> <p>Example 3:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">sentence = 'vou então muito rapidamente partilhar convosco algumas histórias de algumas coisas mágicas que aconteceram.' ground_truth = "so i'll just share with you some stories very quickly of some magical things that have happened." 
translated_text, translated_tokens, attention_weights = translator( tf.constant(sentence)) print_translation(sentence, translated_text, ground_truth) </code></pre> <h2 id="create_attention_plots" data-text="Create attention plots" tabindex="-1">Create attention plots</h2> <p>The <code translate="no" dir="ltr">Translator</code> class you created in the previous section also returns the cross-attention scores from the last decoder layer, which you can use to visualize the internal working of the model.</p> <p>For example:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">sentence = 'este é o primeiro livro que eu fiz.' ground_truth = "this is the first book i've ever done." translated_text, translated_tokens, attention_weights = translator( tf.constant(sentence)) print_translation(sentence, translated_text, ground_truth) </code></pre> <p>Create a function that plots the attention when a token is generated:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">def plot_attention_head(in_tokens, translated_tokens, attention): # The model didn't generate `<START>` in the output. Skip it. translated_tokens = translated_tokens[1:] ax = plt.gca() ax.matshow(attention) ax.set_xticks(range(len(in_tokens))) ax.set_yticks(range(len(translated_tokens))) labels = [label.decode('utf-8') for label in in_tokens.numpy()] ax.set_xticklabels( labels, rotation=90) labels = [label.decode('utf-8') for label in translated_tokens.numpy()] ax.set_yticklabels(labels) </code></pre><pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">head = 0 # Shape: `(batch=1, num_heads, seq_len_q, seq_len_k)`. 
attention_heads = tf.squeeze(attention_weights, 0) attention = attention_heads[head] attention.shape </code></pre> <p>These are the input (Portuguese) tokens:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">in_tokens = tf.convert_to_tensor([sentence]) in_tokens = tokenizers.pt.tokenize(in_tokens).to_tensor() in_tokens = tokenizers.pt.lookup(in_tokens)[0] in_tokens </code></pre> <p>And these are the output (English translation) tokens:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">translated_tokens </code></pre><pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">plot_attention_head(in_tokens, translated_tokens, attention) </code></pre><pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">def plot_attention_weights(sentence, translated_tokens, attention_heads): in_tokens = tf.convert_to_tensor([sentence]) in_tokens = tokenizers.pt.tokenize(in_tokens).to_tensor() in_tokens = tokenizers.pt.lookup(in_tokens)[0] fig = plt.figure(figsize=(16, 8)) for h, head in enumerate(attention_heads): ax = fig.add_subplot(2, 4, h+1) plot_attention_head(in_tokens, translated_tokens, head) ax.set_xlabel(f'Head {h+1}') plt.tight_layout() plt.show() </code></pre><pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">plot_attention_weights(sentence, translated_tokens, attention_weights[0]) </code></pre> <p>The model can handle unfamiliar words. Neither <code translate="no" dir="ltr">'triceratops'</code> nor <code translate="no" dir="ltr">'enciclopédia'</code> is in the input dataset, and the model attempts to transliterate them even without a shared vocabulary. For example:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">sentence = 'Eu li sobre triceratops na enciclopédia.' ground_truth = 'I read about triceratops in the encyclopedia.' 
translated_text, translated_tokens, attention_weights = translator( tf.constant(sentence)) print_translation(sentence, translated_text, ground_truth) plot_attention_weights(sentence, translated_tokens, attention_weights[0]) </code></pre> <h2 id="export_the_model" data-text="Export the model" tabindex="-1">Export the model</h2> <p>You have tested the model and the inference is working. Next, you can export it as a <a href="https://www.tensorflow.org/api_docs/python/tf/saved_model"><code translate="no" dir="ltr">tf.saved_model</code></a>. To learn about saving and loading a model in the SavedModel format, see the <a href="https://www.tensorflow.org/guide/saved_model">SavedModel guide</a>.</p> <p>Create a class called <code translate="no" dir="ltr">ExportTranslator</code> by subclassing <a href="https://www.tensorflow.org/api_docs/python/tf/Module"><code translate="no" dir="ltr">tf.Module</code></a>, with a <a href="https://www.tensorflow.org/api_docs/python/tf/function"><code translate="no" dir="ltr">tf.function</code></a> on the <code translate="no" dir="ltr">__call__</code> method:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">class ExportTranslator(tf.Module): def __init__(self, translator): self.translator = translator @tf.function(input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)]) def __call__(self, sentence): (result, tokens, attention_weights) = self.translator(sentence, max_length=MAX_TOKENS) return result </code></pre> <p>In the above <a href="https://www.tensorflow.org/api_docs/python/tf/function"><code translate="no" dir="ltr">tf.function</code></a>, only the output sentence is returned. 
Thanks to the <a href="https://tensorflow.org/guide/intro_to_graphs">non-strict execution</a> in <a href="https://www.tensorflow.org/api_docs/python/tf/function"><code translate="no" dir="ltr">tf.function</code></a> any unnecessary values are never computed.</p> <p>Wrap <code translate="no" dir="ltr">translator</code> in the newly created <code translate="no" dir="ltr">ExportTranslator</code>:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">translator = ExportTranslator(translator) </code></pre> <p>Since the model is decoding the predictions using <a href="https://www.tensorflow.org/api_docs/python/tf/math/argmax"><code translate="no" dir="ltr">tf.argmax</code></a> the predictions are deterministic. The original model and one reloaded from its <code translate="no" dir="ltr">SavedModel</code> should give identical predictions:</p> <pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">translator('este é o primeiro livro que eu fiz.').numpy() </code></pre><pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">tf.saved_model.save(translator, export_dir='translator') </code></pre><pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">reloaded = tf.saved_model.load('translator') </code></pre><pre class="prettyprint" translate="no" dir="ltr"><code translate="no" dir="ltr">reloaded('este é o primeiro livro que eu fiz.').numpy() </code></pre> <h2 id="conclusion" data-text="Conclusion" tabindex="-1">Conclusion</h2> <p>In this tutorial you learned about:</p> <ul> <li>The Transformers and their significance in machine learning</li> <li>Attention, self-attention and multi-head attention</li> <li>Positional encoding with embeddings</li> <li>The encoder-decoder architecture of the original Transformer</li> <li>Masking in self-attention</li> <li>How to put it all together to translate text</li> </ul> <p>The downsides of this architecture are:</p> <ul> <li>For a 
time-series, the output for a time-step is calculated from the <em>entire history</em> instead of only the inputs and current hidden-state. This <em>may</em> be less efficient.</li> <li>If the input has a temporal/spatial relationship, like text or images, some positional encoding must be added or the model will effectively see a bag of words.</li> </ul> <p>If you want to practice, there are many things you could try with it. For example:</p> <ul> <li>Use a different dataset to train the Transformer.</li> <li>Create the "Base Transformer" or "Big Transformer" configurations from the original paper by changing the hyperparameters.</li> <li>Use the layers defined here to create an implementation of <a href="https://arxiv.org/abs/1810.04805">BERT</a>.</li> <li>Use beam search to get better predictions.</li> </ul> <p>There is a wide variety of Transformer-based models, many of which improve upon the 2017 version of the original Transformer with encoder-decoder, encoder-only and decoder-only architectures.</p> <p>Some of these models are covered in the following research publications:</p> <ul> <li><a href="https://arxiv.org/abs/2009.06732">"Efficient Transformers: a survey"</a> (Tay et al., 2022)</li> <li><a href="https://arxiv.org/abs/2207.09238">"Formal algorithms for Transformers"</a> (Phuong and Hutter, 2022)</li> <li><a href="https://arxiv.org/abs/1910.10683">T5 ("Exploring the limits of transfer learning with a unified text-to-text Transformer")</a> (Raffel et al., 2019)</li> </ul> <p>You can learn more about other models in the following Google blog posts:</p> <ul> <li><a href="https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html">PaLM</a></li> <li><a href="https://ai.googleblog.com/2022/01/lamda-towards-safe-grounded-and-high.html">LaMDA</a></li> <li><a href="https://blog.google/products/search/introducing-mum/">MUM</a></li> <li><a href="https://ai.googleblog.com/2020/01/reformer-efficient-transformer.html">Reformer</a></li> <li><a 
href="https://ai.googleblog.com/2018/11/open-sourcing-bert-state-of-art-pre.html">BERT</a></li> </ul> <p>If you're interested in studying how attention-based models have been applied in tasks outside of natural language processing, check out the following resources:</p> <ul> <li>Vision Transformer (ViT): <a href="https://ai.googleblog.com/2020/12/transformers-for-image-recognition-at.html">Transformers for image recognition at scale</a></li> <li><a href="https://magenta.tensorflow.org/transcription-with-transformers">Multi-task multitrack music transcription (MT3)</a> with a Transformer</li> <li><a href="https://www.deepmind.com/blog/competitive-programming-with-alphacode">Code generation with AlphaCode</a></li> <li><a href="https://ai.googleblog.com/2022/07/training-generalist-agents-with-multi.html">Reinforcement learning with multi-game decision Transformers</a></li> <li><a href="https://www.nature.com/articles/s41586-021-03819-2">Protein structure prediction with AlphaFold</a></li> <li><a href="https://ai.googleblog.com/2022/08/optformer-towards-universal.html">OptFormer: Towards universal hyperparameter optimization with Transformers</a></li> </ul> </div> <devsite-thumb-rating position="footer"> </devsite-thumb-rating> <div class="devsite-floating-action-buttons"> </div> </article> <devsite-content-footer class="nocontent"> <p>Except as otherwise noted, the content of this page is licensed under the <a href="https://creativecommons.org/licenses/by/4.0/">Creative Commons Attribution 4.0 License</a>, and code samples are licensed under the <a href="https://www.apache.org/licenses/LICENSE-2.0">Apache 2.0 License</a>. For details, see the <a href="https://developers.google.com/site-policies">Google Developers Site Policies</a>. 
Java is a registered trademark of Oracle and/or its affiliates.</p> <p>Last updated 2024-05-31 UTC.</p> </devsite-content-footer> <devsite-notification > </devsite-notification> <div class="devsite-content-data"> <template class="devsite-content-data-template"> [[["Easy to understand","easyToUnderstand","thumb-up"],["Solved my problem","solvedMyProblem","thumb-up"],["Other","otherUp","thumb-up"]],[["Missing the information I need","missingTheInformationINeed","thumb-down"],["Too complicated / too many steps","tooComplicatedTooManySteps","thumb-down"],["Out of date","outOfDate","thumb-down"],["Samples / code issue","samplesCodeIssue","thumb-down"],["Other","otherDown","thumb-down"]],["Last updated 2024-05-31 UTC."],[],[]] </template> </div> </devsite-content> </main> <devsite-footer-promos class="devsite-footer"> </devsite-footer-promos> <devsite-footer-linkboxes class="devsite-footer"> <nav class="devsite-footer-linkboxes nocontent" aria-label="Footer links"> <ul class="devsite-footer-linkboxes-list"> <li class="devsite-footer-linkbox "> <h3 class="devsite-footer-linkbox-heading no-link">Stay connected</h3> <ul class="devsite-footer-linkbox-list"> <li class="devsite-footer-linkbox-item"> <a href="//blog.tensorflow.org" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 1)" > Blog </a> </li> <li class="devsite-footer-linkbox-item"> <a href="//discuss.tensorflow.org" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 2)" > Forum </a> </li> <li class="devsite-footer-linkbox-item"> <a href="//github.com/tensorflow/" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 3)" > GitHub </a> </li> <li class="devsite-footer-linkbox-item"> <a href="//twitter.com/tensorflow" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide 
Custom Events" data-label="Footer Link (index 4)" > Twitter </a> </li> <li class="devsite-footer-linkbox-item"> <a href="//youtube.com/tensorflow" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 5)" > YouTube </a> </li> </ul> </li> <li class="devsite-footer-linkbox "> <h3 class="devsite-footer-linkbox-heading no-link">Support</h3> <ul class="devsite-footer-linkbox-list"> <li class="devsite-footer-linkbox-item"> <a href="//github.com/tensorflow/tensorflow/issues" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 1)" > Issue tracker </a> </li> <li class="devsite-footer-linkbox-item"> <a href="//github.com/tensorflow/tensorflow/blob/master/RELEASE.md" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 2)" > Release notes </a> </li> <li class="devsite-footer-linkbox-item"> <a href="//stackoverflow.com/questions/tagged/tensorflow" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 3)" > Stack Overflow </a> </li> <li class="devsite-footer-linkbox-item"> <a href="/extras/tensorflow_brand_guidelines.pdf" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 4)" > Brand guidelines </a> </li> <li class="devsite-footer-linkbox-item"> <a href="/about/bib" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 5)" > Cite TensorFlow </a> </li> </ul> </li> </ul> </nav> </devsite-footer-linkboxes> <devsite-footer-utility class="devsite-footer"> <div class="devsite-footer-utility nocontent"> <nav class="devsite-footer-utility-links" aria-label="Utility links"> <ul class="devsite-footer-utility-list"> <li class="devsite-footer-utility-item "> 
<a class="devsite-footer-utility-link gc-analytics-event" href="//policies.google.com/terms" data-category="Site-Wide Custom Events" data-label="Footer Terms link" > Terms </a> </li> <li class="devsite-footer-utility-item "> <a class="devsite-footer-utility-link gc-analytics-event" href="//policies.google.com/privacy" data-category="Site-Wide Custom Events" data-label="Footer Privacy link" > Privacy </a> </li> <li class="devsite-footer-utility-item glue-cookie-notification-bar-control"> <a class="devsite-footer-utility-link gc-analytics-event" href="#" data-category="Site-Wide Custom Events" data-label="Footer Manage cookies link" aria-hidden="true" > Manage cookies </a> </li> <li class="devsite-footer-utility-item devsite-footer-utility-button"> <span class="devsite-footer-utility-description">Sign up for the TensorFlow newsletter</span> <a class="devsite-footer-utility-link gc-analytics-event" href="//www.tensorflow.org/subscribe" data-category="Site-Wide Custom Events" data-label="Footer Subscribe link" > Subscribe </a> </li> </ul> <devsite-language-selector> <ul role="presentation"> <li role="presentation"> <a role="menuitem" lang="en" >English</a> </li> <li role="presentation"> <a role="menuitem" lang="es_419" >Español – América Latina</a> </li> <li role="presentation"> <a role="menuitem" lang="fr" >Français</a> </li> <li role="presentation"> <a role="menuitem" lang="id" >Indonesia</a> </li> <li role="presentation"> <a role="menuitem" lang="it" >Italiano</a> </li> <li role="presentation"> <a role="menuitem" lang="pl" >Polski</a> </li> <li role="presentation"> <a role="menuitem" lang="pt_br" >Português – Brasil</a> </li> <li role="presentation"> <a role="menuitem" lang="vi" >Tiếng Việt</a> </li> <li role="presentation"> <a role="menuitem" lang="tr" >Türkçe</a> </li> <li role="presentation"> <a role="menuitem" lang="ru" >Русский</a> </li> <li role="presentation"> <a role="menuitem" lang="he" >עברית</a> </li> <li role="presentation"> <a role="menuitem" lang="ar" 
>العربيّة</a> </li> <li role="presentation"> <a role="menuitem" lang="fa" >فارسی</a> </li> <li role="presentation"> <a role="menuitem" lang="hi" >हिंदी</a> </li> <li role="presentation"> <a role="menuitem" lang="bn" >বাংলা</a> </li> <li role="presentation"> <a role="menuitem" lang="th" >ภาษาไทย</a> </li> <li role="presentation"> <a role="menuitem" lang="zh_cn" >中文 – 简体</a> </li> <li role="presentation"> <a role="menuitem" lang="ja" >日本語</a> </li> <li role="presentation"> <a role="menuitem" lang="ko" >한국어</a> </li> </ul> </devsite-language-selector> </nav> </div> </devsite-footer-utility> <devsite-panel></devsite-panel> </section></section> <devsite-sitemask></devsite-sitemask> <devsite-snackbar></devsite-snackbar> <devsite-tooltip ></devsite-tooltip> <devsite-heading-link></devsite-heading-link> <devsite-analytics> <script type="application/json" analytics>[{"dimensions": {"dimension1": "Signed out", "dimension4": "Text", "dimension5": "en", "dimension6": "en", "dimension12": false, "dimension3": false}, "gaid": "UA-69864048-1", "metrics": {"ratings_value": "metric1", "ratings_count": "metric2"}, "purpose": 0}]</script> <script type="application/json" tag-management>{"at": "True", "ga4": [], "ga4p": [], "gtm": [{"id": "GTM-MXSL34P", "purpose": 0}], "parameters": {"internalUser": "False", "language": {"machineTranslated": "False", "requested": "en", "served": "en"}, "pageType": "article", "projectName": "Text", "signedIn": "False", "tenant": "tensorflow", "recommendations": {"sourcePage": "", "sourceType": 0, "sourceRank": 0, "sourceIdenticalDescriptions": 0, "sourceTitleWords": 0, "sourceDescriptionWords": 0, "experiment": ""}, "experiment": {"ids": ""}}}</script> </devsite-analytics> <devsite-badger></devsite-badger> <script nonce="7YE+J3scG0ilxHQ//aFBun03cZ2pa1"> (function(d,e,v,s,i,t,E){d['GoogleDevelopersObject']=i; t=e.createElement(v);t.async=1;t.src=s;E=e.getElementsByTagName(v)[0]; E.parentNode.insertBefore(t,E);})(window, document, 'script', 
'https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/js/app_loader.js', '[15,"en",null,"/js/devsite_app_module.js","https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625","https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow","https://tensorflow-dot-devsite-v2-prod-3p.appspot.com",null,null,["/_pwa/tensorflow/manifest.json","https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/images/video-placeholder.svg","https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/images/favicon.png","https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/images/lockup.svg","https://fonts.googleapis.com/css?family=Google+Sans:400,500|Roboto:400,400italic,500,500italic,700,700italic|Roboto+Mono:400,500,700&display=swap"],1,null,[1,6,8,12,14,17,21,25,50,52,63,70,75,76,80,87,91,92,93,97,98,100,101,102,103,104,105,107,108,109,110,112,113,117,118,120,122,124,125,126,127,129,130,131,132,133,134,135,136,138,140,141,147,148,149,151,152,156,157,158,159,161,163,164,168,169,170,179,180,182,183,186,191,193,196],"AIzaSyCNm9YxQumEXwGJgTDjxoxXK6m1F-9720Q","AIzaSyCc76DZePGtoyUjqKrLdsMGk_ry7sljLbY","www.tensorflow.org","AIzaSyB9bqgQ2t11WJsOX8qNsCQ6U-w91mmqF-I","AIzaSyAdYnStPdzjcJJtQ0mvIaeaMKj7_t6J_Fg",null,null,null,["DevPro__enable_developer_subscriptions","MiscFeatureFlags__enable_project_variables","SignIn__enable_refresh_access_tokens","Cloud__enable_cloud_shell_fte_user_flow","Analytics__enable_clearcut_logging","Cloud__enable_cloud_dlp_service","Cloud__enable_cloudx_ping","Profiles__enable_profile_collections","Concierge__enable_pushui","CloudShell__cloud_code_overflow_menu","Cloud__enable_cloud_shell","Profiles__enable_d
ashboard_curated_recommendations","Search__enable_dynamic_content_confidential_banner","CloudShell__cloud_shell_button","Profiles__enable_awarding_url","MiscFeatureFlags__developers_footer_dark_image","EngEduTelemetry__enable_engedu_telemetry","MiscFeatureFlags__enable_view_transitions","Cloud__enable_legacy_calculator_redirect","MiscFeatureFlags__emergency_css","Profiles__enable_completecodelab_endpoint","Profiles__enable_developer_profiles_callout","BookNav__enable_tenant_cache_key","MiscFeatureFlags__enable_variable_operator","TpcFeatures__enable_required_headers","Profiles__require_profile_eligibility_for_signin","Search__enable_suggestions_from_borg","Cloud__enable_cloud_facet_chat","Profiles__enable_page_saving","MiscFeatureFlags__enable_explain_this_code","Profiles__enable_release_notes_notifications","Cloud__enable_cloudx_experiment_ids","Search__enable_ai_eligibility_checks","Cloud__enable_free_trial_server_call","DevPro__enable_cloud_innovators_plus","Experiments__reqs_query_experiments","Search__enable_page_map","Cloud__enable_llm_concierge_chat","Profiles__enable_public_developer_profiles","MiscFeatureFlags__developers_footer_image","MiscFeatureFlags__enable_firebase_utm","Profiles__enable_recognition_badges","TpcFeatures__enable_mirror_tenant_redirects","Profiles__enable_complete_playlist_endpoint"],null,null,"AIzaSyA58TaKli1DculwmAmbpzLVGuWc8eCQgQc","https://developerscontentserving-pa.googleapis.com","AIzaSyDWBU60w0P9hEkr29kkksYs8Z7gvZ8u_wc","https://developerscontentsearch-pa.googleapis.com",2,4,null,"https://developerprofiles-pa.googleapis.com",[15,"tensorflow","TensorFlow","www.tensorflow.org",null,"tensorflow-dot-devsite-v2-prod-3p.appspot.com",null,null,[null,1,null,null,null,null,null,null,null,null,null,[1],null,null,null,null,null,null,[1],[1,null,null,[1]],null,null,null,[1,null,1],[1,1,null,1,1]],null,[25,null,null,null,null,null,"/images/lockup.svg","/images/logo.png",null,null,null,1,1,null,null,null,null,null,null,null,null,1,null,null,nu
ll,null,[]],[],null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,[6,1],null,[[],[1,1]],[[["UA-69864048-1"],["UA-69864048-4"],null,null,["UA-69864048-5"],["GTM-MXSL34P"],null,null,[["UA-69864048-1",1]],null,[["UA-69864048-5",1]],[["GTM-MXSL34P",1]],1],[[5,4],[12,8],[1,1],[6,5],[4,3],[3,2]],[[2,2],[1,1]]],null,4]]') </script> <devsite-a11y-announce></devsite-a11y-announce> </body> </html>