<!-- CINXE.COM
     Graph regularization for sentiment classification using synthesized graphs | Neural Structured Learning | TensorFlow -->
<!doctype html> <html lang="en" dir="ltr"> <head> <meta name="google-signin-client-id" content="157101835696-ooapojlodmuabs2do2vuhhnf90bccmoi.apps.googleusercontent.com"> <meta name="google-signin-scope" content="profile email https://www.googleapis.com/auth/developerprofiles https://www.googleapis.com/auth/developerprofiles.award"> <meta property="og:site_name" content="TensorFlow"> <meta property="og:type" content="website"><meta name="theme-color" content="#ff6f00"><meta charset="utf-8"> <meta content="IE=Edge" http-equiv="X-UA-Compatible"> <meta name="viewport" content="width=device-width, initial-scale=1"> <link rel="manifest" href="/_pwa/tensorflow/manifest.json" crossorigin="use-credentials"> <link rel="preconnect" href="//www.gstatic.com" crossorigin> <link rel="preconnect" href="//fonts.gstatic.com" crossorigin> <link rel="preconnect" href="//fonts.googleapis.com" crossorigin> <link rel="preconnect" href="//apis.google.com" crossorigin> <link rel="preconnect" href="//www.google-analytics.com" crossorigin><link rel="stylesheet" href="//fonts.googleapis.com/css?family=Google+Sans:400,500|Roboto:400,400italic,500,500italic,700,700italic|Roboto+Mono:400,500,700&display=swap"> <link rel="stylesheet" href="//fonts.googleapis.com/css2?family=Material+Icons&family=Material+Symbols+Outlined&display=block"><link rel="stylesheet" href="https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/css/app.css"> <link rel="shortcut icon" href="https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/images/favicon.png"> <link rel="apple-touch-icon" href="https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/images/apple-touch-icon-180x180.png"><link rel="canonical" href="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb"><link rel="search" 
type="application/opensearchdescription+xml" title="TensorFlow" href="https://www.tensorflow.org/s/opensearch.xml"> <link rel="alternate" hreflang="en" href="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb" /><link rel="alternate" hreflang="x-default" href="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb" /><link rel="alternate" hreflang="ar" href="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb?hl=ar" /><link rel="alternate" hreflang="bn" href="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb?hl=bn" /><link rel="alternate" hreflang="zh-Hans" href="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb?hl=zh-cn" /><link rel="alternate" hreflang="fa" href="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb?hl=fa" /><link rel="alternate" hreflang="fr" href="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb?hl=fr" /><link rel="alternate" hreflang="he" href="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb?hl=he" /><link rel="alternate" hreflang="hi" href="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb?hl=hi" /><link rel="alternate" hreflang="id" href="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb?hl=id" /><link rel="alternate" hreflang="it" href="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb?hl=it" /><link rel="alternate" hreflang="ja" href="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb?hl=ja" /><link rel="alternate" hreflang="ko" href="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb?hl=ko" /><link rel="alternate" hreflang="pl" 
href="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb?hl=pl" /><link rel="alternate" hreflang="pt-BR" href="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb?hl=pt-br" /><link rel="alternate" hreflang="ru" href="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb?hl=ru" /><link rel="alternate" hreflang="es-419" href="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb?hl=es-419" /><link rel="alternate" hreflang="th" href="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb?hl=th" /><link rel="alternate" hreflang="tr" href="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb?hl=tr" /><link rel="alternate" hreflang="vi" href="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb?hl=vi" /><title>Graph regularization for sentiment classification using synthesized graphs | Neural Structured Learning | TensorFlow</title> <meta property="og:title" content="Graph regularization for sentiment classification using synthesized graphs | Neural Structured Learning | TensorFlow"><meta property="og:url" content="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb"><meta property="og:image" content="https://www.tensorflow.org/static/images/tf_logo_social.png"> <meta property="og:image:width" content="1200"> <meta property="og:image:height" content="675"><meta property="og:locale" content="en"><meta name="twitter:card" content="summary_large_image"><script type="application/ld+json"> { "@context": "https://schema.org", "@type": "Article", "headline": "Graph regularization for sentiment classification using synthesized graphs" } </script><script type="application/ld+json"> { "@context": "https://schema.org", "@type": "BreadcrumbList", "itemListElement": [{ "@type": "ListItem", "position": 1, "name": "Neural 
Structured Learning", "item": "https://www.tensorflow.org/neural_structured_learning" },{ "@type": "ListItem", "position": 2, "name": "Graph regularization for sentiment classification using synthesized graphs", "item": "https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb" }] } </script> <link rel="stylesheet" href="/extras.css"></head> <body class="" template="page" theme="tensorflow-theme" type="article" layout="docs" display-toc pending> <devsite-progress type="indeterminate" id="app-progress"></devsite-progress> <section class="devsite-wrapper"> <devsite-cookie-notification-bar></devsite-cookie-notification-bar><devsite-header role="banner"> <div class="devsite-header--inner nocontent"> <div class="devsite-top-logo-row-wrapper-wrapper"> <div class="devsite-top-logo-row-wrapper"> <div class="devsite-top-logo-row"> <button type="button" id="devsite-hamburger-menu" class="devsite-header-icon-button button-flat material-icons gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Navigation menu button" visually-hidden aria-label="Open menu"> </button> <div class="devsite-product-name-wrapper"> <a href="/" class="devsite-site-logo-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Site logo" track-type="globalNav" track-name="tensorFlow" track-metadata-position="nav" track-metadata-eventDetail="nav"> <picture> <img src="https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/images/lockup.svg" class="devsite-site-logo" alt="TensorFlow"> </picture> </a> <span class="devsite-product-name"> <ul class="devsite-breadcrumb-list" > <li class="devsite-breadcrumb-item "> </li> </ul> </span> </div> <div class="devsite-top-logo-row-middle"> <div class="devsite-header-upper-tabs"> <devsite-tabs class="upper-tabs"> <nav class="devsite-tabs-wrapper" aria-label="Upper tabs"> <tab > <a href="https://www.tensorflow.org/install" 
track-metadata-eventdetail="https://www.tensorflow.org/install" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - install" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Install" track-name="install" > Install </a> </tab> <tab class="devsite-dropdown "> <a href="https://www.tensorflow.org/learn" track-metadata-eventdetail="https://www.tensorflow.org/learn" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - learn" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Learn" track-name="learn" > Learn </a> <a href="#" role="button" aria-haspopup="true" aria-expanded="false" aria-label="Dropdown menu for Learn" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/learn" track-metadata-position="nav - learn" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Learn" track-name="learn" class="devsite-tabs-dropdown-toggle devsite-icon devsite-icon-arrow-drop-down"></a> <div class="devsite-tabs-dropdown" aria-label="submenu" hidden> <div class="devsite-tabs-dropdown-content"> <div class="devsite-tabs-dropdown-column tfo-menu-column-learn"> <ul class="devsite-tabs-dropdown-section "> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/learn" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/learn" track-metadata-position="nav - learn" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Introduction </div> <div class="devsite-nav-item-description"> New to TensorFlow? 
</div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/tutorials" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/tutorials" track-metadata-position="nav - learn" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Tutorials </div> <div class="devsite-nav-item-description"> Learn how to use TensorFlow with end-to-end examples </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/guide" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/guide" track-metadata-position="nav - learn" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Guide </div> <div class="devsite-nav-item-description"> Learn framework concepts and components </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/resources/learn-ml" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/resources/learn-ml" track-metadata-position="nav - learn" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Learn ML </div> <div class="devsite-nav-item-description"> Educational resources to master your path with TensorFlow </div> </a> </li> </ul> </div> </div> </div> </tab> <tab class="devsite-dropdown "> <a href="https://www.tensorflow.org/api" track-metadata-eventdetail="https://www.tensorflow.org/api" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - api" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: API" track-name="api" > API </a> <a href="#" role="button" aria-haspopup="true" aria-expanded="false" aria-label="Dropdown menu for API" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/api" track-metadata-position="nav - api" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: API" track-name="api" 
class="devsite-tabs-dropdown-toggle devsite-icon devsite-icon-arrow-drop-down"></a> <div class="devsite-tabs-dropdown" aria-label="submenu" hidden> <div class="devsite-tabs-dropdown-content"> <div class="devsite-tabs-dropdown-column "> <ul class="devsite-tabs-dropdown-section "> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/api/stable" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/api/stable" track-metadata-position="nav - api" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> TensorFlow (v2.16.1) </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/versions" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/versions" track-metadata-position="nav - api" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Versions… </div> </a> </li> </ul> </div> <div class="devsite-tabs-dropdown-column "> <ul class="devsite-tabs-dropdown-section "> <li class="devsite-nav-item"> <a href="https://js.tensorflow.org/api/latest/" track-type="nav" track-metadata-eventdetail="https://js.tensorflow.org/api/latest/" track-metadata-position="nav - api" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> TensorFlow.js </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/lite/api_docs" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/lite/api_docs" track-metadata-position="nav - api" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> TensorFlow Lite </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/tfx/api_docs" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/tfx/api_docs" track-metadata-position="nav - api" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> TFX </div> </a> </li> </ul> </div> </div> </div> </tab> <tab 
class="devsite-dropdown devsite-active "> <a href="https://www.tensorflow.org/resources" track-metadata-eventdetail="https://www.tensorflow.org/resources" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - resources" track-metadata-module="primary nav" aria-label="Resources, selected" data-category="Site-Wide Custom Events" data-label="Tab: Resources" track-name="resources" > Resources </a> <a href="#" role="button" aria-haspopup="true" aria-expanded="false" aria-label="Dropdown menu for Resources" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/resources" track-metadata-position="nav - resources" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Resources" track-name="resources" class="devsite-tabs-dropdown-toggle devsite-icon devsite-icon-arrow-drop-down"></a> <div class="devsite-tabs-dropdown" aria-label="submenu" hidden> <div class="devsite-tabs-dropdown-content"> <div class="devsite-tabs-dropdown-column "> <ul class="devsite-tabs-dropdown-section "> <li class="devsite-nav-title" role="heading" tooltip>LIBRARIES</li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/js" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/js" track-metadata-position="nav - resources" track-metadata-module="tertiary nav" track-metadata-module_headline="libraries" tooltip > <div class="devsite-nav-item-title"> TensorFlow.js </div> <div class="devsite-nav-item-description"> Develop web ML applications in JavaScript </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/lite" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/lite" track-metadata-position="nav - resources" track-metadata-module="tertiary nav" track-metadata-module_headline="libraries" tooltip > <div class="devsite-nav-item-title"> TensorFlow Lite </div> <div class="devsite-nav-item-description"> Deploy ML on mobile, 
microcontrollers and other edge devices </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/tfx" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/tfx" track-metadata-position="nav - resources" track-metadata-module="tertiary nav" track-metadata-module_headline="libraries" tooltip > <div class="devsite-nav-item-title"> TFX </div> <div class="devsite-nav-item-description"> Build production ML pipelines </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/resources/libraries-extensions" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/resources/libraries-extensions" track-metadata-position="nav - resources" track-metadata-module="tertiary nav" track-metadata-module_headline="libraries" tooltip > <div class="devsite-nav-item-title"> All libraries </div> <div class="devsite-nav-item-description"> Create advanced models and extend TensorFlow </div> </a> </li> </ul> </div> <div class="devsite-tabs-dropdown-column "> <ul class="devsite-tabs-dropdown-section "> <li class="devsite-nav-title" role="heading" tooltip>RESOURCES</li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/resources/models-datasets" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/resources/models-datasets" track-metadata-position="nav - resources" track-metadata-module="tertiary nav" track-metadata-module_headline="resources" tooltip > <div class="devsite-nav-item-title"> Models & datasets </div> <div class="devsite-nav-item-description"> Pre-trained models and datasets built by Google and the community </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/resources/tools" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/resources/tools" track-metadata-position="nav - resources" track-metadata-module="tertiary nav" track-metadata-module_headline="resources" tooltip > <div class="devsite-nav-item-title"> 
Tools </div> <div class="devsite-nav-item-description"> Tools to support and accelerate TensorFlow workflows </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/responsible_ai" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/responsible_ai" track-metadata-position="nav - resources" track-metadata-module="tertiary nav" track-metadata-module_headline="resources" tooltip > <div class="devsite-nav-item-title"> Responsible AI </div> <div class="devsite-nav-item-description"> Resources for every stage of the ML workflow </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/resources/recommendation-systems" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/resources/recommendation-systems" track-metadata-position="nav - resources" track-metadata-module="tertiary nav" track-metadata-module_headline="resources" tooltip > <div class="devsite-nav-item-title"> Recommendation systems </div> <div class="devsite-nav-item-description"> Build recommendation systems with open source tools </div> </a> </li> </ul> </div> </div> </div> </tab> <tab class="devsite-dropdown "> <a href="https://www.tensorflow.org/community" track-metadata-eventdetail="https://www.tensorflow.org/community" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - community" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Community" track-name="community" > Community </a> <a href="#" role="button" aria-haspopup="true" aria-expanded="false" aria-label="Dropdown menu for Community" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/community" track-metadata-position="nav - community" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Community" track-name="community" class="devsite-tabs-dropdown-toggle devsite-icon devsite-icon-arrow-drop-down"></a> <div 
class="devsite-tabs-dropdown" aria-label="submenu" hidden> <div class="devsite-tabs-dropdown-content"> <div class="devsite-tabs-dropdown-column "> <ul class="devsite-tabs-dropdown-section "> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/community/groups" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/community/groups" track-metadata-position="nav - community" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Groups </div> <div class="devsite-nav-item-description"> User groups, interest groups and mailing lists </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/community/contribute" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/community/contribute" track-metadata-position="nav - community" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Contribute </div> <div class="devsite-nav-item-description"> Guide for contributing to code and documentation </div> </a> </li> <li class="devsite-nav-item"> <a href="https://blog.tensorflow.org/" track-type="nav" track-metadata-eventdetail="https://blog.tensorflow.org/" track-metadata-position="nav - community" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Blog </div> <div class="devsite-nav-item-description"> Stay up to date with all things TensorFlow </div> </a> </li> <li class="devsite-nav-item"> <a href="https://discuss.tensorflow.org" track-type="nav" track-metadata-eventdetail="https://discuss.tensorflow.org" track-metadata-position="nav - community" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Forum </div> <div class="devsite-nav-item-description"> Discussion platform for the TensorFlow community </div> </a> </li> </ul> </div> </div> </div> </tab> <tab class="devsite-dropdown "> <a href="https://www.tensorflow.org/about" 
track-metadata-eventdetail="https://www.tensorflow.org/about" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - why tensorflow" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Why TensorFlow" track-name="why tensorflow" > Why TensorFlow </a> <a href="#" role="button" aria-haspopup="true" aria-expanded="false" aria-label="Dropdown menu for Why TensorFlow" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/about" track-metadata-position="nav - why tensorflow" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Why TensorFlow" track-name="why tensorflow" class="devsite-tabs-dropdown-toggle devsite-icon devsite-icon-arrow-drop-down"></a> <div class="devsite-tabs-dropdown" aria-label="submenu" hidden> <div class="devsite-tabs-dropdown-content"> <div class="devsite-tabs-dropdown-column "> <ul class="devsite-tabs-dropdown-section "> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/about" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/about" track-metadata-position="nav - why tensorflow" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> About </div> </a> </li> <li class="devsite-nav-item"> <a href="https://www.tensorflow.org/about/case-studies" track-type="nav" track-metadata-eventdetail="https://www.tensorflow.org/about/case-studies" track-metadata-position="nav - why tensorflow" track-metadata-module="tertiary nav" tooltip > <div class="devsite-nav-item-title"> Case studies </div> </a> </li> </ul> </div> </div> </div> </tab> </nav> </devsite-tabs> </div> <devsite-search enable-signin enable-search enable-suggestions enable-query-completion project-name="Neural Structured Learning" tenant-name="TensorFlow" > <form class="devsite-search-form" action="https://www.tensorflow.org/s/results" method="GET"> <div class="devsite-search-container"> <button 
type="button" search-open class="devsite-search-button devsite-header-icon-button button-flat material-icons" aria-label="Open search"></button> <div class="devsite-searchbox"> <input aria-activedescendant="" aria-autocomplete="list" aria-label="Search" aria-expanded="false" aria-haspopup="listbox" autocomplete="off" class="devsite-search-field devsite-search-query" name="q" placeholder="Search" role="combobox" type="text" value="" > <div class="devsite-search-image material-icons" aria-hidden="true"> </div> <div class="devsite-search-shortcut-icon-container" aria-hidden="true"> <kbd class="devsite-search-shortcut-icon">/</kbd> </div> </div> </div> </form> <button type="button" search-close class="devsite-search-button devsite-header-icon-button button-flat material-icons" aria-label="Close search"></button> </devsite-search> </div> <devsite-language-selector> <ul role="presentation"> <li role="presentation"> <a role="menuitem" lang="en" >English</a> </li> <li role="presentation"> <a role="menuitem" lang="es_419" >Español – América Latina</a> </li> <li role="presentation"> <a role="menuitem" lang="fr" >Français</a> </li> <li role="presentation"> <a role="menuitem" lang="id" >Indonesia</a> </li> <li role="presentation"> <a role="menuitem" lang="it" >Italiano</a> </li> <li role="presentation"> <a role="menuitem" lang="pl" >Polski</a> </li> <li role="presentation"> <a role="menuitem" lang="pt_br" >Português – Brasil</a> </li> <li role="presentation"> <a role="menuitem" lang="vi" >Tiếng Việt</a> </li> <li role="presentation"> <a role="menuitem" lang="tr" >Türkçe</a> </li> <li role="presentation"> <a role="menuitem" lang="ru" >Русский</a> </li> <li role="presentation"> <a role="menuitem" lang="he" >עברית</a> </li> <li role="presentation"> <a role="menuitem" lang="ar" >العربيّة</a> </li> <li role="presentation"> <a role="menuitem" lang="fa" >فارسی</a> </li> <li role="presentation"> <a role="menuitem" lang="hi" >हिंदी</a> </li> <li role="presentation"> <a role="menuitem" 
lang="bn" >বাংলা</a> </li> <li role="presentation"> <a role="menuitem" lang="th" >ภาษาไทย</a> </li> <li role="presentation"> <a role="menuitem" lang="zh_cn" >中文 – 简体</a> </li> <li role="presentation"> <a role="menuitem" lang="ja" >日本語</a> </li> <li role="presentation"> <a role="menuitem" lang="ko" >한국어</a> </li> </ul> </devsite-language-selector> <a class="devsite-header-link devsite-top-button button gc-analytics-event" href="//github.com/tensorflow" data-category="Site-Wide Custom Events" data-label="Site header link" > GitHub </a> <devsite-user enable-profiles id="devsite-user"> <span class="button devsite-top-button" aria-hidden="true" visually-hidden>Sign in</span> </devsite-user> </div> </div> </div> <div class="devsite-collapsible-section "> <div class="devsite-header-background"> <div class="devsite-product-id-row" > <div class="devsite-product-description-row"> <ul class="devsite-breadcrumb-list" > <li class="devsite-breadcrumb-item "> <a href="https://www.tensorflow.org/neural_structured_learning" class="devsite-breadcrumb-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Lower Header" data-value="1" track-type="globalNav" track-name="breadcrumb" track-metadata-position="1" track-metadata-eventdetail="Neural Structured Learning" > Neural Structured Learning </a> </li> </ul> </div> </div> <div class="devsite-doc-set-nav-row"> <devsite-tabs class="lower-tabs"> <nav class="devsite-tabs-wrapper" aria-label="Lower tabs"> <tab > <a href="https://www.tensorflow.org/neural_structured_learning" track-metadata-eventdetail="https://www.tensorflow.org/neural_structured_learning" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - overview" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: Overview" track-name="overview" > Overview </a> </tab> <tab class="devsite-active"> <a href="https://www.tensorflow.org/neural_structured_learning/framework" 
track-metadata-eventdetail="https://www.tensorflow.org/neural_structured_learning/framework" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - guide & tutorials" track-metadata-module="primary nav" aria-label="Guide & Tutorials, selected" data-category="Site-Wide Custom Events" data-label="Tab: Guide & Tutorials" track-name="guide & tutorials" > Guide & Tutorials </a> </tab> <tab > <a href="https://www.tensorflow.org/neural_structured_learning/api_docs/python/nsl" track-metadata-eventdetail="https://www.tensorflow.org/neural_structured_learning/api_docs/python/nsl" class="devsite-tabs-content gc-analytics-event " track-type="nav" track-metadata-position="nav - api" track-metadata-module="primary nav" data-category="Site-Wide Custom Events" data-label="Tab: API" track-name="api" > API </a> </tab> </nav> </devsite-tabs> </div> </div> </div> </div> </devsite-header> <devsite-book-nav scrollbars > <div class="devsite-book-nav-filter" > <span class="filter-list-icon material-icons" aria-hidden="true"></span> <input type="text" placeholder="Filter" aria-label="Type to filter" role="searchbox"> <span class="filter-clear-button hidden" data-title="Clear filter" aria-label="Clear filter" role="button" tabindex="0"></span> </div> <nav class="devsite-book-nav devsite-nav nocontent" aria-label="Side menu"> <div class="devsite-mobile-header"> <button type="button" id="devsite-close-nav" class="devsite-header-icon-button button-flat material-icons gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Close navigation" aria-label="Close navigation"> </button> <div class="devsite-product-name-wrapper"> <a href="/" class="devsite-site-logo-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Site logo" track-type="globalNav" track-name="tensorFlow" track-metadata-position="nav" track-metadata-eventDetail="nav"> <picture> <img 
src="https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/images/lockup.svg" class="devsite-site-logo" alt="TensorFlow"> </picture> </a> <span class="devsite-product-name"> <ul class="devsite-breadcrumb-list" > <li class="devsite-breadcrumb-item "> </li> </ul> </span> </div> </div> <div class="devsite-book-nav-wrapper"> <div class="devsite-mobile-nav-top"> <ul class="devsite-nav-list"> <li class="devsite-nav-item"> <a href="/install" class="devsite-nav-title gc-analytics-event devsite-nav-has-children " data-category="Site-Wide Custom Events" data-label="Tab: Install" track-name="install" data-category="Site-Wide Custom Events" data-label="Responsive Tab: Install" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Install </span> <span class="devsite-nav-icon material-icons" data-icon="forward" > </span> </a> </li> <li class="devsite-nav-item"> <a href="/learn" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Tab: Learn" track-name="learn" data-category="Site-Wide Custom Events" data-label="Responsive Tab: Learn" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Learn </span> </a> <ul class="devsite-nav-responsive-tabs devsite-nav-has-menu "> <li class="devsite-nav-item"> <span class="devsite-nav-title" tooltip data-category="Site-Wide Custom Events" data-label="Tab: Learn" track-name="learn" > <span class="devsite-nav-text" tooltip menu="Learn"> More </span> <span class="devsite-nav-icon material-icons" data-icon="forward" menu="Learn"> </span> </span> </li> </ul> </li> <li class="devsite-nav-item"> <a href="/api" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Tab: API" track-name="api" data-category="Site-Wide Custom Events" 
data-label="Responsive Tab: API" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > API </span> </a> <ul class="devsite-nav-responsive-tabs devsite-nav-has-menu "> <li class="devsite-nav-item"> <span class="devsite-nav-title" tooltip data-category="Site-Wide Custom Events" data-label="Tab: API" track-name="api" > <span class="devsite-nav-text" tooltip menu="API"> More </span> <span class="devsite-nav-icon material-icons" data-icon="forward" menu="API"> </span> </span> </li> </ul> </li> <li class="devsite-nav-item"> <a href="/resources" class="devsite-nav-title gc-analytics-event devsite-nav-active" data-category="Site-Wide Custom Events" data-label="Tab: Resources" track-name="resources" data-category="Site-Wide Custom Events" data-label="Responsive Tab: Resources" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Resources </span> </a> <ul class="devsite-nav-responsive-tabs devsite-nav-has-menu "> <li class="devsite-nav-item"> <span class="devsite-nav-title" tooltip data-category="Site-Wide Custom Events" data-label="Tab: Resources" track-name="resources" > <span class="devsite-nav-text" tooltip menu="Resources"> More </span> <span class="devsite-nav-icon material-icons" data-icon="forward" menu="Resources"> </span> </span> </li> </ul> <ul class="devsite-nav-responsive-tabs"> <li class="devsite-nav-item"> <a href="/neural_structured_learning" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Tab: Overview" track-name="overview" data-category="Site-Wide Custom Events" data-label="Responsive Tab: Overview" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Overview </span> </a> </li> <li class="devsite-nav-item"> <a href="/neural_structured_learning/framework" 
class="devsite-nav-title gc-analytics-event devsite-nav-has-children devsite-nav-active" data-category="Site-Wide Custom Events" data-label="Tab: Guide & Tutorials" track-name="guide & tutorials" data-category="Site-Wide Custom Events" data-label="Responsive Tab: Guide & Tutorials" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip menu="_book"> Guide & Tutorials </span> <span class="devsite-nav-icon material-icons" data-icon="forward" menu="_book"> </span> </a> </li> <li class="devsite-nav-item"> <a href="/neural_structured_learning/api_docs/python/nsl" class="devsite-nav-title gc-analytics-event devsite-nav-has-children " data-category="Site-Wide Custom Events" data-label="Tab: API" track-name="api" data-category="Site-Wide Custom Events" data-label="Responsive Tab: API" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > API </span> <span class="devsite-nav-icon material-icons" data-icon="forward" > </span> </a> </li> </ul> </li> <li class="devsite-nav-item"> <a href="/community" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Tab: Community" track-name="community" data-category="Site-Wide Custom Events" data-label="Responsive Tab: Community" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Community </span> </a> <ul class="devsite-nav-responsive-tabs devsite-nav-has-menu "> <li class="devsite-nav-item"> <span class="devsite-nav-title" tooltip data-category="Site-Wide Custom Events" data-label="Tab: Community" track-name="community" > <span class="devsite-nav-text" tooltip menu="Community"> More </span> <span class="devsite-nav-icon material-icons" data-icon="forward" menu="Community"> </span> </span> </li> </ul> </li> <li class="devsite-nav-item"> <a href="/about" 
class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Tab: Why TensorFlow" track-name="why tensorflow" data-category="Site-Wide Custom Events" data-label="Responsive Tab: Why TensorFlow" track-type="globalNav" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Why TensorFlow </span> </a> <ul class="devsite-nav-responsive-tabs devsite-nav-has-menu "> <li class="devsite-nav-item"> <span class="devsite-nav-title" tooltip data-category="Site-Wide Custom Events" data-label="Tab: Why TensorFlow" track-name="why tensorflow" > <span class="devsite-nav-text" tooltip menu="Why TensorFlow"> More </span> <span class="devsite-nav-icon material-icons" data-icon="forward" menu="Why TensorFlow"> </span> </span> </li> </ul> </li> <li class="devsite-nav-item"> <a href="//github.com/tensorflow" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: GitHub" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > GitHub </span> </a> </li> </ul> </div> <div class="devsite-mobile-nav-bottom"> <ul class="devsite-nav-list" menu="_book"> <li class="devsite-nav-item"><a href="/neural_structured_learning/framework" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /neural_structured_learning/framework" track-type="bookNav" track-name="click" track-metadata-eventdetail="/neural_structured_learning/framework" ><span class="devsite-nav-text" tooltip>Framework</span></a></li> <li class="devsite-nav-item"><a href="/neural_structured_learning/install" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /neural_structured_learning/install" track-type="bookNav" track-name="click" 
track-metadata-eventdetail="/neural_structured_learning/install" ><span class="devsite-nav-text" tooltip>Install</span></a></li> <li class="devsite-nav-item devsite-nav-heading"><div class="devsite-nav-title devsite-nav-title-no-path"> <span class="devsite-nav-text" tooltip>Neural graph learning tutorials</span> </div></li> <li class="devsite-nav-item"><a href="/neural_structured_learning/tutorials/graph_keras_mlp_cora" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /neural_structured_learning/tutorials/graph_keras_mlp_cora" track-type="bookNav" track-name="click" track-metadata-eventdetail="/neural_structured_learning/tutorials/graph_keras_mlp_cora" ><span class="devsite-nav-text" tooltip>Graph regularization for document classification using natural graphs</span></a></li> <li class="devsite-nav-item"><a href="/neural_structured_learning/tutorials/graph_keras_lstm_imdb" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /neural_structured_learning/tutorials/graph_keras_lstm_imdb" track-type="bookNav" track-name="click" track-metadata-eventdetail="/neural_structured_learning/tutorials/graph_keras_lstm_imdb" ><span class="devsite-nav-text" tooltip>Graph regularization for sentiment classification using synthesized graphs</span></a></li> <li class="devsite-nav-item devsite-nav-heading"><div class="devsite-nav-title devsite-nav-title-no-path"> <span class="devsite-nav-text" tooltip>Adversarial learning tutorials</span> </div></li> <li class="devsite-nav-item"><a href="/neural_structured_learning/tutorials/adversarial_keras_cnn_mnist" class="devsite-nav-title gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Book nav link, pathname: /neural_structured_learning/tutorials/adversarial_keras_cnn_mnist" track-type="bookNav" track-name="click" 
track-metadata-eventdetail="/neural_structured_learning/tutorials/adversarial_keras_cnn_mnist" ><span class="devsite-nav-text" tooltip>Adversarial regularization for image classification</span></a></li> </ul> <ul class="devsite-nav-list" menu="Learn" aria-label="Side menu" hidden> <li class="devsite-nav-item"> <a href="/learn" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Introduction" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Introduction </span> </a> </li> <li class="devsite-nav-item"> <a href="/tutorials" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Tutorials" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Tutorials </span> </a> </li> <li class="devsite-nav-item"> <a href="/guide" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Guide" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Guide </span> </a> </li> <li class="devsite-nav-item"> <a href="/resources/learn-ml" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Learn ML" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Learn ML </span> </a> </li> </ul> <ul class="devsite-nav-list" menu="API" aria-label="Side menu" hidden> <li class="devsite-nav-item"> <a href="/api/stable" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: TensorFlow (v2.16.1)" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span 
class="devsite-nav-text" tooltip > TensorFlow (v2.16.1) </span> </a> </li> <li class="devsite-nav-item"> <a href="/versions" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Versions…" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Versions… </span> </a> </li> <li class="devsite-nav-item"> <a href="https://js.tensorflow.org/api/latest/" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: TensorFlow.js" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > TensorFlow.js </span> </a> </li> <li class="devsite-nav-item"> <a href="/lite/api_docs" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: TensorFlow Lite" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > TensorFlow Lite </span> </a> </li> <li class="devsite-nav-item"> <a href="/tfx/api_docs" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: TFX" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > TFX </span> </a> </li> </ul> <ul class="devsite-nav-list" menu="Resources" aria-label="Side menu" hidden> <li class="devsite-nav-item devsite-nav-heading"> <span class="devsite-nav-title" tooltip > <span class="devsite-nav-text" tooltip > LIBRARIES </span> </span> </li> <li class="devsite-nav-item"> <a href="/js" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: TensorFlow.js" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span 
class="devsite-nav-text" tooltip > TensorFlow.js </span> </a> </li> <li class="devsite-nav-item"> <a href="/lite" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: TensorFlow Lite" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > TensorFlow Lite </span> </a> </li> <li class="devsite-nav-item"> <a href="/tfx" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: TFX" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > TFX </span> </a> </li> <li class="devsite-nav-item"> <a href="/resources/libraries-extensions" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: All libraries" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > All libraries </span> </a> </li> <li class="devsite-nav-item devsite-nav-heading"> <span class="devsite-nav-title" tooltip > <span class="devsite-nav-text" tooltip > RESOURCES </span> </span> </li> <li class="devsite-nav-item"> <a href="/resources/models-datasets" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Models & datasets" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Models & datasets </span> </a> </li> <li class="devsite-nav-item"> <a href="/resources/tools" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Tools" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Tools </span> </a> </li> <li 
class="devsite-nav-item"> <a href="/responsible_ai" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Responsible AI" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Responsible AI </span> </a> </li> <li class="devsite-nav-item"> <a href="/resources/recommendation-systems" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Recommendation systems" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Recommendation systems </span> </a> </li> </ul> <ul class="devsite-nav-list" menu="Community" aria-label="Side menu" hidden> <li class="devsite-nav-item"> <a href="/community/groups" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Groups" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Groups </span> </a> </li> <li class="devsite-nav-item"> <a href="/community/contribute" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Contribute" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Contribute </span> </a> </li> <li class="devsite-nav-item"> <a href="https://blog.tensorflow.org/" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Blog" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Blog </span> </a> </li> <li class="devsite-nav-item"> <a href="https://discuss.tensorflow.org" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide 
Custom Events" data-label="Responsive Tab: Forum" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Forum </span> </a> </li> </ul> <ul class="devsite-nav-list" menu="Why TensorFlow" aria-label="Side menu" hidden> <li class="devsite-nav-item"> <a href="/about" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: About" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > About </span> </a> </li> <li class="devsite-nav-item"> <a href="/about/case-studies" class="devsite-nav-title gc-analytics-event " data-category="Site-Wide Custom Events" data-label="Responsive Tab: Case studies" track-type="navMenu" track-metadata-eventDetail="globalMenu" track-metadata-position="nav"> <span class="devsite-nav-text" tooltip > Case studies </span> </a> </li> </ul> </div> </div> </nav> </devsite-book-nav> <section id="gc-wrapper"> <main role="main" class="devsite-main-content" has-book-nav has-sidebar > <div class="devsite-sidebar"> <div class="devsite-sidebar-content"> <devsite-toc class="devsite-nav" role="navigation" aria-label="On this page" depth="2" scrollbars ></devsite-toc> <devsite-recommendations-sidebar class="nocontent devsite-nav"> </devsite-recommendations-sidebar> </div> </div> <devsite-content> <article class="devsite-article"><style> /* Styles inlined from /site-assets/css/style.css */ /* override theme */ table img { max-width: 100%; } /* .devsite-terminal virtualenv prompt */ .tfo-terminal-venv::before { content: "(venv) $ " !important; } /* .devsite-terminal root prompt */ .tfo-terminal-root::before { content: "# " !important; } /* Used in links for type annotations in function/method signatures */ .tfo-signature-link a, .tfo-signature-link a:visited, .tfo-signature-link a:hover, .tfo-signature-link a:focus, .tfo-signature-link a:hover *, 
.tfo-signature-link a:focus * { text-decoration: none !important; } .tfo-signature-link a, .tfo-signature-link a:visited { border-bottom: 1px dotted #1a73e8; } .tfo-signature-link a:focus { border-bottom-style: solid; } /* .devsite-terminal Windows prompt */ .tfo-terminal-windows::before { content: "C:\\> " !important; } /* .devsite-terminal Windows prompt w/ virtualenv */ .tfo-terminal-windows-venv::before { content: "(venv) C:\\> " !important; } .tfo-diff-green-one-level + * { background: rgba(175, 245, 162, .6) !important; } .tfo-diff-green + * > * { background: rgba(175, 245, 162, .6) !important; } .tfo-diff-green-list + ul > li:first-of-type { background: rgba(175, 245, 162, .6) !important; } .tfo-diff-red-one-level + * { background: rgba(255, 230, 230, .6) !important; text-decoration: line-through !important; } .tfo-diff-red + * > * { background: rgba(255, 230, 230, .6) !important; text-decoration: line-through !important; } .tfo-diff-red-list + ul > li:first-of-type { background: rgba(255, 230, 230, .6) !important; text-decoration: line-through !important; } devsite-code .tfo-notebook-code-cell-output { max-height: 300px; overflow: auto; background: rgba(255, 247, 237, 1); /* orange bg to distinguish from input code cells */ } devsite-code .tfo-notebook-code-cell-output + .devsite-code-buttons-container button { background: rgba(255, 247, 237, .7); /* orange bg to distinguish from input code cells */ } devsite-code[dark-code] .tfo-notebook-code-cell-output { background: rgba(64, 78, 103, 1); /* medium slate */ } devsite-code[dark-code] .tfo-notebook-code-cell-output + .devsite-code-buttons-container button { background: rgba(64, 78, 103, .7); /* medium slate */ } /* override default table styles for notebook buttons */ .devsite-table-wrapper .tfo-notebook-buttons { display: inline-block; margin-left: 3px; width: auto; } .tfo-notebook-buttons td { padding-left: 0; padding-right: 20px; } .tfo-notebook-buttons a, .tfo-notebook-buttons :link, 
.tfo-notebook-buttons :visited { border-radius: 8px; box-shadow: 0 1px 2px 0 rgba(60, 64, 67, .3), 0 1px 3px 1px rgba(60, 64, 67, .15); color: #202124; padding: 12px 17px; transition: box-shadow 0.2s; } .tfo-notebook-buttons a:hover, .tfo-notebook-buttons a:focus { box-shadow: 0 1px 2px 0 rgba(60, 64, 67, .3), 0 2px 6px 2px rgba(60, 64, 67, .15); } .tfo-notebook-buttons tr { background: 0; border: 0; } /* on rendered notebook page, remove link to webpage since we're already here */ .tfo-notebook-buttons:not(.tfo-api) td:first-child { display: none; } .tfo-notebook-buttons td > a { -webkit-box-align: center; -ms-flex-align: center; align-items: center; display: -webkit-box; display: -ms-flexbox; display: flex; } .tfo-notebook-buttons td > a > img { margin-right: 8px; } /* landing pages */ .tfo-landing-row-item-inset-white { background-color: #fff; padding: 32px; } .tfo-landing-row-item-inset-white ol, .tfo-landing-row-item-inset-white ul { padding-left: 20px; } /* colab callout button */ .colab-callout-row devsite-code { border-radius: 8px 8px 0 0; box-shadow: none; } .colab-callout-footer { background: #e3e4e7; border-radius: 0 0 8px 8px; color: #37474f; padding: 20px; } .colab-callout-row devsite-code[dark-code] + .colab-callout-footer { background: #3f4f66; } .colab-callout-footer > .button { margin-top: 4px; color: #ff5c00; } .colab-callout-footer > a > span { vertical-align: middle; color: #37474f; padding-left: 10px; font-size: 14px; } .colab-callout-row devsite-code[dark-code] + .colab-callout-footer > a > span { color: #fff; } a.colab-button { background: rgba(255, 255, 255, .75); border: solid 1px rgba(0, 0, 0, .08); border-bottom-color: rgba(0, 0, 0, .15); border-radius: 4px; color: #aaa; display: inline-block; font-size: 11px !important; font-weight: 300; line-height: 16px; padding: 4px 8px; text-decoration: none; text-transform: uppercase; } a.colab-button:hover { background: white; border-color: rgba(0, 0, 0, .2); color: #666; } a.colab-button span { 
background: url(/images/colab_logo_button.svg) no-repeat 1px 1px / 20px; border-radius: 4px; display: inline-block; padding-left: 24px; text-decoration: none; } @media screen and (max-width: 600px) { .tfo-notebook-buttons td { display: block; } } /* guide and tutorials landing page cards and sections */ .tfo-landing-page-card { padding: 16px; box-shadow: 0 0 36px rgba(0,0,0,0.1); border-radius: 10px; } /* Page section headings */ .tfo-landing-page-heading h2, h2.tfo-landing-page-heading { font-family: "Google Sans", sans-serif; color: #425066; font-size: 30px; font-weight: 700; line-height: 40px; } /* Item title headings */ .tfo-landing-page-heading h3, h3.tfo-landing-page-heading, .tfo-landing-page-card h3, h3.tfo-landing-page-card { font-family: "Google Sans", sans-serif; color: #425066; font-size: 20px; font-weight: 500; line-height: 26px; } /* List of tutorials notebooks for subsites */ .tfo-landing-page-resources-ul { padding-left: 15px } .tfo-landing-page-resources-ul > li { margin: 6px 0; } /* Temporary fix to hide product description in header on landing pages */ devsite-header .devsite-product-description { display: none; } </style> <div class="devsite-article-meta nocontent" role="navigation"> <ul class="devsite-breadcrumb-list" aria-label="Breadcrumb"> <li class="devsite-breadcrumb-item "> <a href="https://www.tensorflow.org/" class="devsite-breadcrumb-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Breadcrumbs" data-value="1" track-type="globalNav" track-name="breadcrumb" track-metadata-position="1" track-metadata-eventdetail="TensorFlow" > TensorFlow </a> </li> <li class="devsite-breadcrumb-item "> <div class="devsite-breadcrumb-guillemet material-icons" aria-hidden="true"></div> <a href="https://www.tensorflow.org/resources" class="devsite-breadcrumb-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Breadcrumbs" data-value="2" track-type="globalNav" track-name="breadcrumb" 
track-metadata-position="2" track-metadata-eventdetail="" > Resources </a> </li> <li class="devsite-breadcrumb-item "> <div class="devsite-breadcrumb-guillemet material-icons" aria-hidden="true"></div> <a href="https://www.tensorflow.org/neural_structured_learning" class="devsite-breadcrumb-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Breadcrumbs" data-value="3" track-type="globalNav" track-name="breadcrumb" track-metadata-position="3" track-metadata-eventdetail="Neural Structured Learning" > Neural Structured Learning </a> </li> <li class="devsite-breadcrumb-item "> <div class="devsite-breadcrumb-guillemet material-icons" aria-hidden="true"></div> <a href="https://www.tensorflow.org/neural_structured_learning/framework" class="devsite-breadcrumb-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Breadcrumbs" data-value="4" track-type="globalNav" track-name="breadcrumb" track-metadata-position="4" track-metadata-eventdetail="" > Guide & Tutorials </a> </li> </ul> <devsite-thumb-rating position="header"> </devsite-thumb-rating> </div> <h1 class="devsite-page-title" tabindex="-1"> Graph regularization for sentiment classification using synthesized graphs </h1> <devsite-feature-tooltip ack-key="AckCollectionsBookmarkTooltipDismiss" analytics-category="Site-Wide Custom Events" analytics-action-show="Callout Profile displayed" analytics-action-close="Callout Profile dismissed" analytics-label="Create Collection Callout" class="devsite-page-bookmark-tooltip nocontent" dismiss-button="true" id="devsite-collections-dropdown" dismiss-button-text="Dismiss" close-button-text="Got it"> <devsite-bookmark></devsite-bookmark> <span slot="popout-heading"> Stay organized with collections </span> <span slot="popout-contents"> Save and categorize content based on your preferences. 
</span> </devsite-feature-tooltip> <div class="devsite-page-title-meta"><devsite-view-release-notes></devsite-view-release-notes></div> <devsite-toc class="devsite-nav" depth="2" devsite-toc-embedded > </devsite-toc> <div class="devsite-article-body clearfix "> <p></p> <!-- DO NOT EDIT! Automatically generated file. --> <div itemscope itemtype="http://developers.google.com/ReferenceObject"> <meta itemprop="name" content="Graph regularization for sentiment classification using synthesized graphs" /> <meta itemprop="path" content="Guide & Tutorials" /> <meta itemprop="property" content="nsl.configs.GraphBuilderConfig"/> <meta itemprop="property" content="nsl.configs.make_graph_reg_config"/> <meta itemprop="property" content="nsl.keras.GraphRegularization"/> <meta itemprop="property" content="nsl.tools.build_graph_from_config"/> <meta itemprop="property" content="nsl.tools.pack_nbrs"/> </div> <table class="tfo-notebook-buttons" align="left"> <tr> <td> <a target="_blank" href="https://www.tensorflow.org/neural_structured_learning/tutorials/graph_keras_lstm_imdb"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" alt="">View on TensorFlow.org</a> </td> <td> <a target="_blank" href="https://colab.research.google.com/github/tensorflow/neural-structured-learning/blob/master/g3doc/tutorials/graph_keras_lstm_imdb.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" alt="">Run in Google Colab</a> </td> <td> <a target="_blank" href="https://github.com/tensorflow/neural-structured-learning/blob/master/g3doc/tutorials/graph_keras_lstm_imdb.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" alt="">View source on GitHub</a> </td> <td> <a href="https://storage.googleapis.com/tensorflow_docs/neural-structured-learning/g3doc/tutorials/graph_keras_lstm_imdb.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" alt="">Download notebook</a> </td> <td> <a href="https://tfhub.dev/"><img src="https://www.tensorflow.org/images/hub_logo_32px.png" alt="">See 
TF Hub model</a> </td> </tr> </table> <h2 id="overview" data-text="Overview" tabindex="-1">Overview</h2> <p>This notebook classifies movie reviews as <em>positive</em> or <em>negative</em> using the text of the review. This is an example of <em>binary</em> classification, an important and widely applicable kind of machine learning problem.</p> <p>We will demonstrate the use of graph regularization in this notebook by building a graph from the given input. The general recipe for building a graph-regularized model using the Neural Structured Learning (NSL) framework when the input does not contain an explicit graph is as follows:</p> <ol> <li>Create embeddings for each text sample in the input. This can be done using pre-trained models such as <a href="https://arxiv.org/pdf/1310.4546.pdf">word2vec</a>, <a href="https://arxiv.org/abs/1602.02215">Swivel</a>, <a href="https://arxiv.org/abs/1810.04805">BERT</a>, etc.</li> <li>Build a graph based on these embeddings by using a similarity metric such as the 'L2' distance, 'cosine' distance, etc. Nodes in the graph correspond to samples and edges in the graph correspond to similarity between pairs of samples.</li> <li>Generate training data from the above synthesized graph and sample features. The resulting training data will contain neighbor features in addition to the original node features.</li> <li>Create a neural network as a base model using the Keras sequential, functional, or subclass API.</li> <li>Wrap the base model with the GraphRegularization wrapper class, which is provided by the NSL framework, to create a new graph Keras model. 
This new model will include a graph regularization loss as the regularization term in its training objective.</li> <li>Train and evaluate the graph Keras model.</li> </ol> <aside class="note"><strong>Note:</strong><span> We expect that it would take readers about 1 hour to go through this tutorial.</span></aside> <h2 id="requirements" data-text="Requirements" tabindex="-1">Requirements</h2> <ol> <li>Install the Neural Structured Learning package.</li> <li>Install tensorflow-hub.</li> </ol> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Bash"><code class='devsite-terminal' translate="no" dir="ltr">pip<span class="devsite-syntax-w"> </span>install<span class="devsite-syntax-w"> </span>--quiet<span class="devsite-syntax-w"> </span>neural-structured-learning</code> <code class='devsite-terminal' translate="no" dir="ltr">pip<span class="devsite-syntax-w"> </span>install<span class="devsite-syntax-w"> </span>--quiet<span class="devsite-syntax-w"> </span>tensorflow-hub</code></pre></devsite-code> <h2 id="dependencies_and_imports" data-text="Dependencies and imports" tabindex="-1">Dependencies and imports</h2> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-kn">import</span> <span class="devsite-syntax-nn">matplotlib.pyplot</span> <span class="devsite-syntax-k">as</span> <span class="devsite-syntax-nn">plt</span> <span class="devsite-syntax-kn">import</span> <span class="devsite-syntax-nn">numpy</span> <span class="devsite-syntax-k">as</span> <span class="devsite-syntax-nn">np</span> <span class="devsite-syntax-kn">import</span> <span class="devsite-syntax-nn">neural_structured_learning</span> <span class="devsite-syntax-k">as</span> <span class="devsite-syntax-nn">nsl</span> <span class="devsite-syntax-kn">import</span> <span class="devsite-syntax-nn">tensorflow</span> <span 
class="devsite-syntax-k">as</span> <span class="devsite-syntax-nn">tf</span> <span class="devsite-syntax-kn">import</span> <span class="devsite-syntax-nn">tensorflow_hub</span> <span class="devsite-syntax-k">as</span> <span class="devsite-syntax-nn">hub</span> <span class="devsite-syntax-c1"># Resets notebook state</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">keras</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">backend</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">clear_session</span><span class="devsite-syntax-p">()</span> <span class="devsite-syntax-nb">print</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s2">"Version: "</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">__version__</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-nb">print</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s2">"Eager mode: "</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">executing_eagerly</span><span class="devsite-syntax-p">())</span> <span class="devsite-syntax-nb">print</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s2">"Hub version: "</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">hub</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">__version__</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-nb">print</span><span class="devsite-syntax-p">(</span> <span class="devsite-syntax-s2">"GPU is"</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-s2">"available"</span> <span class="devsite-syntax-k">if</span> <span 
class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">config</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">list_physical_devices</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s2">"GPU"</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-k">else</span> <span class="devsite-syntax-s2">"NOT AVAILABLE"</span><span class="devsite-syntax-p">)</span> </code></pre></devsite-code> <div></div><devsite-code><pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr" is-upgraded> 2022-12-14 12:19:13.551836: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory 2022-12-14 12:19:13.551949: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory 2022-12-14 12:19:13.551962: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly. Version: 2.11.0 Eager mode: True Hub version: 0.12.0 GPU is NOT AVAILABLE 2022-12-14 12:19:14.770677: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected </pre></devsite-code> <h2 id="imdb_dataset" data-text="IMDB dataset" tabindex="-1">IMDB dataset</h2> <p>The <a href="https://www.tensorflow.org/api_docs/python/tf/keras/datasets/imdb">IMDB dataset</a> contains the text of 50,000 movie reviews from the <a href="https://www.imdb.com/">Internet Movie Database</a>. 
These are split into 25,000 reviews for training and 25,000 reviews for testing. The training and testing sets are <em>balanced</em>, meaning they contain an equal number of positive and negative reviews.</p> <p>In this tutorial, we will use a preprocessed version of the IMDB dataset.</p> <h3 id="download_preprocessed_imdb_dataset" data-text="Download preprocessed IMDB dataset" tabindex="-1">Download preprocessed IMDB dataset</h3> <p>The IMDB dataset comes packaged with TensorFlow. It has already been preprocessed such that the reviews (sequences of words) have been converted to sequences of integers, where each integer represents a specific word in a dictionary.</p> <p>The following code downloads the IMDB dataset (or uses a cached copy if it has already been downloaded):</p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-n">imdb</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">keras</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">datasets</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">imdb</span> <span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">pp_train_data</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">pp_train_labels</span><span class="devsite-syntax-p">),</span> <span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">pp_test_data</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">pp_test_labels</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-p">(</span> <span class="devsite-syntax-n">imdb</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">load_data</span><span 
class="devsite-syntax-p">(</span><span class="devsite-syntax-n">num_words</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-mi">10000</span><span class="devsite-syntax-p">))</span> </code></pre></devsite-code> <div></div><devsite-code><pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr" is-upgraded> Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb.npz 17464789/17464789 [==============================] - 0s 0us/step </pre></devsite-code> <p>The argument <code translate="no" dir="ltr">num_words=10000</code> keeps the top 10,000 most frequently occurring words in the training data. The rare words are discarded to keep the size of the vocabulary manageable.</p> <h3 id="explore_the_data" data-text="Explore the data" tabindex="-1">Explore the data</h3> <p>Let's take a moment to understand the format of the data. The dataset comes preprocessed: each example is an array of integers representing the words of the movie review. 
Each label is an integer value of either 0 or 1, where 0 is a negative review, and 1 is a positive review.</p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-nb">print</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s1">'Training entries: </span><span class="devsite-syntax-si">{}</span><span class="devsite-syntax-s1">, labels: </span><span class="devsite-syntax-si">{}</span><span class="devsite-syntax-s1">'</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">format</span><span class="devsite-syntax-p">(</span> <span class="devsite-syntax-nb">len</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">pp_train_data</span><span class="devsite-syntax-p">),</span> <span class="devsite-syntax-nb">len</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">pp_train_labels</span><span class="devsite-syntax-p">)))</span> <span class="devsite-syntax-n">training_samples_count</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-nb">len</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">pp_train_data</span><span class="devsite-syntax-p">)</span> </code></pre></devsite-code> <div></div><devsite-code><pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr" is-upgraded> Training entries: 25000, labels: 25000 </pre></devsite-code> <p>The text of reviews has been converted to integers, where each integer represents a specific word in a dictionary. 
Here's what the first review looks like:</p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-nb">print</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">pp_train_data</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-mi">0</span><span class="devsite-syntax-p">])</span> </code></pre></devsite-code> <div></div><devsite-code><pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr" is-upgraded> [1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65, 458, 4468, 66, 3941, 4, 173, 36, 256, 5, 25, 100, 43, 838, 112, 50, 670, 2, 9, 35, 480, 284, 5, 150, 4, 172, 112, 167, 2, 336, 385, 39, 4, 172, 4536, 1111, 17, 546, 38, 13, 447, 4, 192, 50, 16, 6, 147, 2025, 19, 14, 22, 4, 1920, 4613, 469, 4, 22, 71, 87, 12, 16, 43, 530, 38, 76, 15, 13, 1247, 4, 22, 17, 515, 17, 12, 16, 626, 18, 2, 5, 62, 386, 12, 8, 316, 8, 106, 5, 4, 2223, 5244, 16, 480, 66, 3785, 33, 4, 130, 12, 16, 38, 619, 5, 25, 124, 51, 36, 135, 48, 25, 1415, 33, 6, 22, 12, 215, 28, 77, 52, 5, 14, 407, 16, 82, 2, 8, 4, 107, 117, 5952, 15, 256, 4, 2, 7, 3766, 5, 723, 36, 71, 43, 530, 476, 26, 400, 317, 46, 7, 4, 2, 1029, 13, 104, 88, 4, 381, 15, 297, 98, 32, 2071, 56, 26, 141, 6, 194, 7486, 18, 4, 226, 22, 21, 134, 476, 26, 480, 5, 144, 30, 5535, 18, 51, 36, 28, 224, 92, 25, 104, 4, 226, 65, 16, 38, 1334, 88, 12, 16, 283, 5, 16, 4472, 113, 103, 32, 15, 16, 5345, 19, 178, 32] </pre></devsite-code> <p>Movie reviews may be different lengths. The below code shows the number of words in the first and second reviews. 
Since inputs to a neural network must be the same length, we'll need to resolve this later.</p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-nb">len</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">pp_train_data</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-mi">0</span><span class="devsite-syntax-p">]),</span> <span class="devsite-syntax-nb">len</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">pp_train_data</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-mi">1</span><span class="devsite-syntax-p">])</span> </code></pre></devsite-code> <div></div><devsite-code><pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr" is-upgraded> (218, 189) </pre></devsite-code> <h3 id="convert_the_integers_back_to_words" data-text="Convert the integers back to words" tabindex="-1">Convert the integers back to words</h3> <p>It may be useful to know how to convert integers back to the corresponding text. 
Here, we'll create a helper function to query a dictionary object that contains the integer to string mapping:</p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-k">def</span> <span class="devsite-syntax-nf">build_reverse_word_index</span><span class="devsite-syntax-p">():</span> <span class="devsite-syntax-c1"># A dictionary mapping words to an integer index</span> <span class="devsite-syntax-n">word_index</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">imdb</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">get_word_index</span><span class="devsite-syntax-p">()</span> <span class="devsite-syntax-c1"># The first indices are reserved</span> <span class="devsite-syntax-n">word_index</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-p">{</span><span class="devsite-syntax-n">k</span><span class="devsite-syntax-p">:</span> <span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">v</span> <span class="devsite-syntax-o">+</span> <span class="devsite-syntax-mi">3</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-k">for</span> <span class="devsite-syntax-n">k</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">v</span> <span class="devsite-syntax-ow">in</span> <span class="devsite-syntax-n">word_index</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">items</span><span class="devsite-syntax-p">()}</span> <span class="devsite-syntax-n">word_index</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-s1">'&lt;PAD&gt;'</span><span class="devsite-syntax-p">]</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-mi">0</span> <span class="devsite-syntax-n">word_index</span><span class="devsite-syntax-p">[</span><span 
class="devsite-syntax-s1">'&lt;START&gt;'</span><span class="devsite-syntax-p">]</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-mi">1</span> <span class="devsite-syntax-n">word_index</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-s1">'&lt;UNK&gt;'</span><span class="devsite-syntax-p">]</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-mi">2</span> <span class="devsite-syntax-c1"># unknown</span> <span class="devsite-syntax-n">word_index</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-s1">'&lt;UNUSED&gt;'</span><span class="devsite-syntax-p">]</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-mi">3</span> <span class="devsite-syntax-k">return</span> <span class="devsite-syntax-nb">dict</span><span class="devsite-syntax-p">((</span><span class="devsite-syntax-n">value</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">key</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-k">for</span> <span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">key</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">value</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-ow">in</span> <span class="devsite-syntax-n">word_index</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">items</span><span class="devsite-syntax-p">())</span> <span class="devsite-syntax-n">reverse_word_index</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">build_reverse_word_index</span><span class="devsite-syntax-p">()</span> <span class="devsite-syntax-k">def</span> <span class="devsite-syntax-nf">decode_review</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">text</span><span class="devsite-syntax-p">):</span> <span class="devsite-syntax-k">return</span> <span class="devsite-syntax-s1">' '</span><span 
class="devsite-syntax-o">.</span><span class="devsite-syntax-n">join</span><span class="devsite-syntax-p">([</span><span class="devsite-syntax-n">reverse_word_index</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">get</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">i</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-s1">'?'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-k">for</span> <span class="devsite-syntax-n">i</span> <span class="devsite-syntax-ow">in</span> <span class="devsite-syntax-n">text</span><span class="devsite-syntax-p">])</span> </code></pre></devsite-code> <div></div><devsite-code><pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr" is-upgraded> Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/imdb_word_index.json 1641221/1641221 [==============================] - 0s 0us/step </pre></devsite-code> <p>Now we can use the <code translate="no" dir="ltr">decode_review</code> function to display the text for the first review:</p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-n">decode_review</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">pp_train_data</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-mi">0</span><span class="devsite-syntax-p">])</span> </code></pre></devsite-code> <div></div><devsite-code><pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr" is-upgraded> "&lt;START&gt; this film was just brilliant casting location scenery story direction everyone's really suited the part they played and you could just imagine being there robert &lt;UNK&gt; is an amazing actor and now the same being director &lt;UNK&gt; father came from the same scottish island as myself so i loved the fact there was a real connection with this film 
the witty remarks throughout the film were great it was just brilliant so much that i bought the film as soon as it was released for &lt;UNK&gt; and would recommend it to everyone to watch and the fly fishing was amazing really cried at the end it was so sad and you know what they say if you cry at a film it must have been good and this definitely was also &lt;UNK&gt; to the two little boy's that played the &lt;UNK&gt; of norman and paul they were just brilliant children are often left out of the &lt;UNK&gt; list i think because the stars that play them all grown up are such a big profile for the whole film but these children are amazing and should be praised for what they have done don't you think the whole story was so lovely because it was true and was someone's life after all that was shared with us all" </pre></devsite-code> <h2 id="graph_construction" data-text="Graph construction" tabindex="-1">Graph construction</h2> <p>Graph construction involves creating embeddings for text samples and then using a similarity function to compare the embeddings.</p> <p>Before proceeding further, we first create a directory to store artifacts created by this tutorial.</p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Bash"><code class='devsite-terminal' translate="no" dir="ltr">mkdir<span class="devsite-syntax-w"> </span>-p<span class="devsite-syntax-w"> </span>/tmp/imdb</code></pre></devsite-code> <h3 id="create_sample_embeddings" data-text="Create sample embeddings" tabindex="-1">Create sample embeddings</h3> <p>We will use pretrained Swivel embeddings to create embeddings in the <a href="https://www.tensorflow.org/api_docs/python/tf/train/Example"><code translate="no" dir="ltr">tf.train.Example</code></a> format for each sample in the input. We will store the resulting embeddings in the <code translate="no" dir="ltr">TFRecord</code> format along with an additional feature that represents the ID of each sample. 
This is important and will allow us to match sample embeddings with corresponding nodes in the graph later.</p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-n">pretrained_embedding</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-s1">'https://tfhub.dev/google/tf2-preview/gnews-swivel-20dim/1'</span> <span class="devsite-syntax-n">hub_layer</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">hub</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">KerasLayer</span><span class="devsite-syntax-p">(</span> <span class="devsite-syntax-n">pretrained_embedding</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">input_shape</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-p">[],</span> <span class="devsite-syntax-n">dtype</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">string</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">trainable</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-kc">True</span><span class="devsite-syntax-p">)</span> </code></pre></devsite-code> <div></div><devsite-code><pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr" is-upgraded> WARNING:tensorflow:Please fix your imports. Module tensorflow.python.training.tracking.data_structures has been moved to tensorflow.python.trackable.data_structures. The old module will be deleted in version 2.11. 
</pre></devsite-code> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-k">def</span> <span class="devsite-syntax-nf">_int64_feature</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">value</span><span class="devsite-syntax-p">):</span> <span class="devsite-syntax-w"> </span><span class="devsite-syntax-sd">"""Returns int64 tf.train.Feature."""</span> <span class="devsite-syntax-k">return</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">train</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">Feature</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">int64_list</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">train</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">Int64List</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">value</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">value</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">tolist</span><span class="devsite-syntax-p">()))</span> <span class="devsite-syntax-k">def</span> <span class="devsite-syntax-nf">_bytes_feature</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">value</span><span class="devsite-syntax-p">):</span> <span class="devsite-syntax-w"> </span><span class="devsite-syntax-sd">"""Returns bytes tf.train.Feature."""</span> <span class="devsite-syntax-k">return</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">train</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">Feature</span><span class="devsite-syntax-p">(</span> 
<span class="devsite-syntax-n">bytes_list</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">train</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">BytesList</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">value</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-n">value</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">encode</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s1">'utf-8'</span><span class="devsite-syntax-p">)]))</span> <span class="devsite-syntax-k">def</span> <span class="devsite-syntax-nf">_float_feature</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">value</span><span class="devsite-syntax-p">):</span> <span class="devsite-syntax-w"> </span><span class="devsite-syntax-sd">"""Returns float tf.train.Feature."""</span> <span class="devsite-syntax-k">return</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">train</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">Feature</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">float_list</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">train</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">FloatList</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">value</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">value</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">tolist</span><span class="devsite-syntax-p">()))</span> <span class="devsite-syntax-k">def</span> <span 
class="devsite-syntax-nf">create_embedding_example</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">word_vector</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">record_id</span><span class="devsite-syntax-p">):</span> <span class="devsite-syntax-w"> </span><span class="devsite-syntax-sd">"""Create tf.Example containing the sample's embedding and its ID."""</span> <span class="devsite-syntax-n">text</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">decode_review</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">word_vector</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-c1"># Shape = [batch_size,].</span> <span class="devsite-syntax-n">sentence_embedding</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">hub_layer</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">reshape</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">text</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">shape</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-o">-</span><span class="devsite-syntax-mi">1</span><span class="devsite-syntax-p">,]))</span> <span class="devsite-syntax-c1"># Flatten the sentence embedding back to 1-D.</span> <span class="devsite-syntax-n">sentence_embedding</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">reshape</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">sentence_embedding</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">shape</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-p">[</span><span 
class="devsite-syntax-o">-</span><span class="devsite-syntax-mi">1</span><span class="devsite-syntax-p">])</span> <span class="devsite-syntax-n">features</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-p">{</span> <span class="devsite-syntax-s1">'id'</span><span class="devsite-syntax-p">:</span> <span class="devsite-syntax-n">_bytes_feature</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-nb">str</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">record_id</span><span class="devsite-syntax-p">)),</span> <span class="devsite-syntax-s1">'embedding'</span><span class="devsite-syntax-p">:</span> <span class="devsite-syntax-n">_float_feature</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">sentence_embedding</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">numpy</span><span class="devsite-syntax-p">())</span> <span class="devsite-syntax-p">}</span> <span class="devsite-syntax-k">return</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">train</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">Example</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">features</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">train</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">Features</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">feature</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">features</span><span class="devsite-syntax-p">))</span> <span class="devsite-syntax-k">def</span> <span class="devsite-syntax-nf">create_embeddings</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">word_vectors</span><span 
class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">output_path</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">starting_record_id</span><span class="devsite-syntax-p">):</span> <span class="devsite-syntax-n">record_id</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-nb">int</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">starting_record_id</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-k">with</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">io</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">TFRecordWriter</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">output_path</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-k">as</span> <span class="devsite-syntax-n">writer</span><span class="devsite-syntax-p">:</span> <span class="devsite-syntax-k">for</span> <span class="devsite-syntax-n">word_vector</span> <span class="devsite-syntax-ow">in</span> <span class="devsite-syntax-n">word_vectors</span><span class="devsite-syntax-p">:</span> <span class="devsite-syntax-n">example</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">create_embedding_example</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">word_vector</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">record_id</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">record_id</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">record_id</span> <span class="devsite-syntax-o">+</span> <span class="devsite-syntax-mi">1</span> <span class="devsite-syntax-n">writer</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">write</span><span class="devsite-syntax-p">(</span><span 
class="devsite-syntax-n">example</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">SerializeToString</span><span class="devsite-syntax-p">())</span> <span class="devsite-syntax-k">return</span> <span class="devsite-syntax-n">record_id</span> <span class="devsite-syntax-c1"># Persist TF.Example features containing embeddings for training data in</span> <span class="devsite-syntax-c1"># TFRecord format.</span> <span class="devsite-syntax-n">create_embeddings</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">pp_train_data</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-s1">'/tmp/imdb/embeddings.tfr'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">0</span><span class="devsite-syntax-p">)</span> </code></pre></devsite-code> <div></div><devsite-code><pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr" is-upgraded> 25000 </pre></devsite-code> <h3 id="build_a_graph" data-text="Build a graph" tabindex="-1">Build a graph</h3> <p>Now that we have the sample embeddings, we will use them to build a similarity graph, i.e., nodes in this graph will correspond to samples and edges in this graph will correspond to similarity between pairs of nodes.</p> <p>Neural Structured Learning provides a graph building library to build a graph based on sample embeddings. It uses <a href="https://en.wikipedia.org/wiki/Cosine_similarity"><strong>cosine similarity</strong></a> as the similarity measure to compare embeddings and build edges between them. It also allows us to specify a similarity threshold, which can be used to discard dissimilar edges from the final graph. In this example, using 0.99 as the similarity threshold and 12345 as the random seed, we end up with a graph that has 429,415 bi-directional edges. 
Here we're using the graph builder's support for <a href="https://en.wikipedia.org/wiki/Locality-sensitive_hashing">locality-sensitive hashing</a> (LSH) to speed up graph building. For details on using the graph builder's LSH support, see the <a href="https://www.tensorflow.org/neural_structured_learning/api_docs/python/nsl/tools/build_graph_from_config"><code translate="no" dir="ltr">build_graph_from_config</code></a> API documentation.</p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-n">graph_builder_config</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">nsl</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">configs</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">GraphBuilderConfig</span><span class="devsite-syntax-p">(</span> <span class="devsite-syntax-n">similarity_threshold</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-mf">0.99</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">lsh_splits</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-mi">32</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">lsh_rounds</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-mi">15</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">random_seed</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-mi">12345</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">nsl</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">tools</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">build_graph_from_config</span><span class="devsite-syntax-p">([</span><span class="devsite-syntax-s1">'/tmp/imdb/embeddings.tfr'</span><span 
class="devsite-syntax-p">],</span> <span class="devsite-syntax-s1">'/tmp/imdb/graph_99.tsv'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">graph_builder_config</span><span class="devsite-syntax-p">)</span> </code></pre></devsite-code> <p>Each bi-directional edge is represented by two directed edges in the output TSV file, so that file contains 429,415 * 2 = 858,830 total lines:</p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Bash"><code class='devsite-terminal' translate="no" dir="ltr">wc<span class="devsite-syntax-w"> </span>-l<span class="devsite-syntax-w"> </span>/tmp/imdb/graph_99.tsv</code></pre></devsite-code> <div></div><devsite-code><pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr" is-upgraded> 858830 /tmp/imdb/graph_99.tsv </pre></devsite-code> <aside class="note"><strong>Note:</strong><span> Graph quality and, by extension, embedding quality are very important for graph regularization. While we have used Swivel embeddings in this notebook, using BERT embeddings, for instance, will likely capture review semantics more accurately. We encourage users to use embeddings of their choice and as appropriate to their needs.</span></aside> <h2 id="sample_features" data-text="Sample features" tabindex="-1">Sample features</h2> <p>We create sample features for our problem using the <a href="https://www.tensorflow.org/api_docs/python/tf/train/Example"><code translate="no" dir="ltr">tf.train.Example</code></a> format and persist them in the <code translate="no" dir="ltr">TFRecord</code> format. 
Each sample will include the following three features:</p> <ol> <li><strong>id</strong>: The node ID of the sample.</li> <li><strong>words</strong>: An int64 list containing word IDs.</li> <li><strong>label</strong>: A singleton int64 identifying the target class of the review.</li> </ol> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-k">def</span> <span class="devsite-syntax-nf">create_example</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">word_vector</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">label</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">record_id</span><span class="devsite-syntax-p">):</span> <span class="devsite-syntax-w"> </span><span class="devsite-syntax-sd">"""Create tf.Example containing the sample's word vector, label, and ID."""</span> <span class="devsite-syntax-n">features</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-p">{</span> <span class="devsite-syntax-s1">'id'</span><span class="devsite-syntax-p">:</span> <span class="devsite-syntax-n">_bytes_feature</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-nb">str</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">record_id</span><span class="devsite-syntax-p">)),</span> <span class="devsite-syntax-s1">'words'</span><span class="devsite-syntax-p">:</span> <span class="devsite-syntax-n">_int64_feature</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">np</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">asarray</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">word_vector</span><span class="devsite-syntax-p">)),</span> <span class="devsite-syntax-s1">'label'</span><span class="devsite-syntax-p">:</span> <span 
class="devsite-syntax-n">_int64_feature</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">np</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">asarray</span><span class="devsite-syntax-p">([</span><span class="devsite-syntax-n">label</span><span class="devsite-syntax-p">])),</span> <span class="devsite-syntax-p">}</span> <span class="devsite-syntax-k">return</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">train</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">Example</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">features</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">train</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">Features</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">feature</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">features</span><span class="devsite-syntax-p">))</span> <span class="devsite-syntax-k">def</span> <span class="devsite-syntax-nf">create_records</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">word_vectors</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">labels</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">record_path</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">starting_record_id</span><span class="devsite-syntax-p">):</span> <span class="devsite-syntax-n">record_id</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-nb">int</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">starting_record_id</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-k">with</span> <span 
class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">io</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">TFRecordWriter</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">record_path</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-k">as</span> <span class="devsite-syntax-n">writer</span><span class="devsite-syntax-p">:</span> <span class="devsite-syntax-k">for</span> <span class="devsite-syntax-n">word_vector</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">label</span> <span class="devsite-syntax-ow">in</span> <span class="devsite-syntax-nb">zip</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">word_vectors</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">labels</span><span class="devsite-syntax-p">):</span> <span class="devsite-syntax-n">example</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">create_example</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">word_vector</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">label</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">record_id</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">record_id</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">record_id</span> <span class="devsite-syntax-o">+</span> <span class="devsite-syntax-mi">1</span> <span class="devsite-syntax-n">writer</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">write</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">example</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">SerializeToString</span><span class="devsite-syntax-p">())</span> <span class="devsite-syntax-k">return</span> <span 
class="devsite-syntax-n">record_id</span> <span class="devsite-syntax-c1"># Persist TF.Example features (word vectors and labels) for training and test</span> <span class="devsite-syntax-c1"># data in TFRecord format.</span> <span class="devsite-syntax-n">next_record_id</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">create_records</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">pp_train_data</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">pp_train_labels</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-s1">'/tmp/imdb/train_data.tfr'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">0</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">create_records</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">pp_test_data</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">pp_test_labels</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-s1">'/tmp/imdb/test_data.tfr'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">next_record_id</span><span class="devsite-syntax-p">)</span> </code></pre></devsite-code> <div></div><devsite-code><pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr" is-upgraded> 50000 </pre></devsite-code> <h2 id="augment_training_data_with_graph_neighbors" data-text="Augment training data with graph neighbors" tabindex="-1">Augment training data with graph neighbors</h2> <p>Since we have the sample features and the synthesized graph, we can generate the augmented training data for Neural Structured Learning. The NSL framework provides a library to combine the graph and the sample features to produce the final training data for graph regularization. 
The resulting training data will include original sample features as well as features of their corresponding neighbors.</p> <p>In this tutorial, we consider undirected edges and use a maximum of 3 neighbors per sample to augment training data with graph neighbors.</p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-n">nsl</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">tools</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">pack_nbrs</span><span class="devsite-syntax-p">(</span> <span class="devsite-syntax-s1">'/tmp/imdb/train_data.tfr'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-s1">''</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-s1">'/tmp/imdb/graph_99.tsv'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-s1">'/tmp/imdb/nsl_train_data.tfr'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">add_undirected_edges</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-kc">True</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">max_nbrs</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-mi">3</span><span class="devsite-syntax-p">)</span> </code></pre></devsite-code> <h2 id="base_model" data-text="Base model" tabindex="-1">Base model</h2> <p>We are now ready to build a base model without graph regularization. In order to build this model, we can either use embeddings that were used in building the graph, or we can learn new embeddings jointly along with the classification task. 
For the purpose of this notebook, we will do the latter.</p> <h3 id="global_variables" data-text="Global variables" tabindex="-1">Global variables</h3> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-n">NBR_FEATURE_PREFIX</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-s1">'NL_nbr_'</span> <span class="devsite-syntax-n">NBR_WEIGHT_SUFFIX</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-s1">'_weight'</span> </code></pre></devsite-code> <h3 id="hyperparameters" data-text="Hyperparameters" tabindex="-1">Hyperparameters</h3> <p>We will use an instance of <code translate="no" dir="ltr">HParams</code> to include various hyperparameters and constants used for training and evaluation. We briefly describe each of them below:</p> <ul> <li><p><strong>num_classes</strong>: There are 2 classes -- <em>positive</em> and <em>negative</em>.</p></li> <li><p><strong>max_seq_length</strong>: This is the maximum number of words considered from each movie review in this example.</p></li> <li><p><strong>vocab_size</strong>: This is the size of the vocabulary considered for this example.</p></li> <li><p><strong>distance_type</strong>: This is the distance metric used to regularize the sample with its neighbors.</p></li> <li><p><strong>graph_regularization_multiplier</strong>: This controls the relative weight of the graph regularization term in the overall loss function.</p></li> <li><p><strong>num_neighbors</strong>: The number of neighbors used for graph regularization. 
This value has to be less than or equal to the <code translate="no" dir="ltr">max_nbrs</code> argument used above when invoking <a href="https://www.tensorflow.org/neural_structured_learning/api_docs/python/nsl/tools/pack_nbrs"><code translate="no" dir="ltr">nsl.tools.pack_nbrs</code></a>.</p></li> <li><p><strong>num_fc_units</strong>: The number of units in the fully connected layer of the neural network.</p></li> <li><p><strong>train_epochs</strong>: The number of training epochs.</p></li> <li><p><strong>batch_size</strong>: Batch size used for training and evaluation.</p></li> <li><p><strong>eval_steps</strong>: The number of batches to process before deeming evaluation is complete. If set to <code translate="no" dir="ltr">None</code>, all instances in the test set are evaluated.</p></li> </ul> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-k">class</span> <span class="devsite-syntax-nc">HParams</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-nb">object</span><span class="devsite-syntax-p">):</span> <span class="devsite-syntax-w"> </span><span class="devsite-syntax-sd">"""Hyperparameters used for training."""</span> <span class="devsite-syntax-k">def</span> <span class="devsite-syntax-fm">__init__</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-bp">self</span><span class="devsite-syntax-p">):</span> <span class="devsite-syntax-c1">### dataset parameters</span> <span class="devsite-syntax-bp">self</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">num_classes</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-mi">2</span> <span class="devsite-syntax-bp">self</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">max_seq_length</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-mi">256</span> <span 
class="devsite-syntax-bp">self</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">vocab_size</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-mi">10000</span> <span class="devsite-syntax-c1">### neural graph learning parameters</span> <span class="devsite-syntax-bp">self</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">distance_type</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">nsl</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">configs</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">DistanceType</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">L2</span> <span class="devsite-syntax-bp">self</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">graph_regularization_multiplier</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-mf">0.1</span> <span class="devsite-syntax-bp">self</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">num_neighbors</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-mi">2</span> <span class="devsite-syntax-c1">### model architecture</span> <span class="devsite-syntax-bp">self</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">num_embedding_dims</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-mi">16</span> <span class="devsite-syntax-bp">self</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">num_lstm_dims</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-mi">64</span> <span class="devsite-syntax-bp">self</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">num_fc_units</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-mi">64</span> <span class="devsite-syntax-c1">### training parameters</span> <span 
class="devsite-syntax-bp">self</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">train_epochs</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-mi">10</span> <span class="devsite-syntax-bp">self</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">batch_size</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-mi">128</span> <span class="devsite-syntax-c1">### eval parameters</span> <span class="devsite-syntax-bp">self</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">eval_steps</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-kc">None</span> <span class="devsite-syntax-c1"># All instances in the test set are evaluated.</span> <span class="devsite-syntax-n">HPARAMS</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">HParams</span><span class="devsite-syntax-p">()</span> </code></pre></devsite-code> <h3 id="prepare_the_data" data-text="Prepare the data" tabindex="-1">Prepare the data</h3> <p>The reviews—the arrays of integers—must be converted to tensors before being fed into the neural network. This conversion can be done a couple of ways:</p> <ul> <li><p>Convert the arrays into vectors of <code translate="no" dir="ltr">0</code>s and <code translate="no" dir="ltr">1</code>s indicating word occurrence, similar to a one-hot encoding. For example, the sequence <code translate="no" dir="ltr">[3, 5]</code> would become a <code translate="no" dir="ltr">10000</code>-dimensional vector that is all zeros except for indices <code translate="no" dir="ltr">3</code> and <code translate="no" dir="ltr">5</code>, which are ones. Then, make this the first layer in our network—a <code translate="no" dir="ltr">Dense</code> layer—that can handle floating point vector data. 
This approach is memory intensive, though, requiring a <code translate="no" dir="ltr">num_words * num_reviews</code> size matrix.</p></li> <li><p>Alternatively, we can pad the arrays so they all have the same length, then create an integer tensor of shape <code translate="no" dir="ltr">max_length * num_reviews</code>. We can use an embedding layer capable of handling this shape as the first layer in our network.</p></li> </ul> <p>In this tutorial, we will use the second approach.</p> <p>Since the movie reviews must be the same length, we will use the <code translate="no" dir="ltr">pad_sequence</code> function defined below to standardize the lengths.</p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-k">def</span> <span class="devsite-syntax-nf">make_dataset</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">file_path</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">training</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-kc">False</span><span class="devsite-syntax-p">):</span> <span class="devsite-syntax-w"> </span><span class="devsite-syntax-sd">"""Creates a `tf.data.TFRecordDataset`.</span> <span class="devsite-syntax-sd"> Args:</span> <span class="devsite-syntax-sd"> file_path: Name of the file in the `.tfrecord` format containing</span> <span class="devsite-syntax-sd"> `tf.train.Example` objects.</span> <span class="devsite-syntax-sd"> training: Boolean indicating if we are in training mode.</span> <span class="devsite-syntax-sd"> Returns:</span> <span class="devsite-syntax-sd"> An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`</span> <span class="devsite-syntax-sd"> objects.</span> <span class="devsite-syntax-sd"> """</span> <span class="devsite-syntax-k">def</span> <span class="devsite-syntax-nf">pad_sequence</span><span 
class="devsite-syntax-p">(</span><span class="devsite-syntax-n">sequence</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">max_seq_length</span><span class="devsite-syntax-p">):</span> <span class="devsite-syntax-w"> </span><span class="devsite-syntax-sd">"""Pads the input sequence (a `tf.SparseTensor`) to `max_seq_length`."""</span> <span class="devsite-syntax-n">pad_size</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">maximum</span><span class="devsite-syntax-p">([</span><span class="devsite-syntax-mi">0</span><span class="devsite-syntax-p">],</span> <span class="devsite-syntax-n">max_seq_length</span> <span class="devsite-syntax-o">-</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">shape</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">sequence</span><span class="devsite-syntax-p">)[</span><span class="devsite-syntax-mi">0</span><span class="devsite-syntax-p">])</span> <span class="devsite-syntax-n">padded</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">concat</span><span class="devsite-syntax-p">(</span> <span class="devsite-syntax-p">[</span><span class="devsite-syntax-n">sequence</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">values</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">fill</span><span class="devsite-syntax-p">((</span><span class="devsite-syntax-n">pad_size</span><span class="devsite-syntax-p">),</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">cast</span><span class="devsite-syntax-p">(</span><span 
class="devsite-syntax-mi">0</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">sequence</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">dtype</span><span class="devsite-syntax-p">))],</span> <span class="devsite-syntax-n">axis</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-mi">0</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-c1"># The input sequence may be larger than max_seq_length. Truncate down if</span> <span class="devsite-syntax-c1"># necessary.</span> <span class="devsite-syntax-k">return</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">slice</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">padded</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-p">[</span><span class="devsite-syntax-mi">0</span><span class="devsite-syntax-p">],</span> <span class="devsite-syntax-p">[</span><span class="devsite-syntax-n">max_seq_length</span><span class="devsite-syntax-p">])</span> <span class="devsite-syntax-k">def</span> <span class="devsite-syntax-nf">parse_example</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">example_proto</span><span class="devsite-syntax-p">):</span> <span class="devsite-syntax-w"> </span><span class="devsite-syntax-sd">"""Extracts relevant fields from the `example_proto`.</span> <span class="devsite-syntax-sd"> Args:</span> <span class="devsite-syntax-sd"> example_proto: An instance of `tf.train.Example`.</span> <span class="devsite-syntax-sd"> Returns:</span> <span class="devsite-syntax-sd"> A pair whose first value is a dictionary containing relevant features</span> <span class="devsite-syntax-sd"> and whose second value contains the ground truth labels.</span> <span class="devsite-syntax-sd"> """</span> <span class="devsite-syntax-c1"># The 'words' feature is a variable length word ID 
vector.</span> <span class="devsite-syntax-n">feature_spec</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-p">{</span> <span class="devsite-syntax-s1">'words'</span><span class="devsite-syntax-p">:</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">io</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">VarLenFeature</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">int64</span><span class="devsite-syntax-p">),</span> <span class="devsite-syntax-s1">'label'</span><span class="devsite-syntax-p">:</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">io</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">FixedLenFeature</span><span class="devsite-syntax-p">((),</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">int64</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">default_value</span><span class="devsite-syntax-o">=-</span><span class="devsite-syntax-mi">1</span><span class="devsite-syntax-p">),</span> <span class="devsite-syntax-p">}</span> <span class="devsite-syntax-c1"># We also extract corresponding neighbor features in a similar manner to</span> <span class="devsite-syntax-c1"># the features above during training.</span> <span class="devsite-syntax-k">if</span> <span class="devsite-syntax-n">training</span><span class="devsite-syntax-p">:</span> <span class="devsite-syntax-k">for</span> <span class="devsite-syntax-n">i</span> <span class="devsite-syntax-ow">in</span> <span class="devsite-syntax-nb">range</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">HPARAMS</span><span class="devsite-syntax-o">.</span><span 
class="devsite-syntax-n">num_neighbors</span><span class="devsite-syntax-p">):</span> <span class="devsite-syntax-n">nbr_feature_key</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-s1">'</span><span class="devsite-syntax-si">{}{}</span><span class="devsite-syntax-s1">_</span><span class="devsite-syntax-si">{}</span><span class="devsite-syntax-s1">'</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">format</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">NBR_FEATURE_PREFIX</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">i</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-s1">'words'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">nbr_weight_key</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-s1">'</span><span class="devsite-syntax-si">{}{}{}</span><span class="devsite-syntax-s1">'</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">format</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">NBR_FEATURE_PREFIX</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">i</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">NBR_WEIGHT_SUFFIX</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">feature_spec</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-n">nbr_feature_key</span><span class="devsite-syntax-p">]</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">io</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">VarLenFeature</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">int64</span><span 
class="devsite-syntax-p">)</span> <span class="devsite-syntax-c1"># We assign a default value of 0.0 for the neighbor weight so that</span> <span class="devsite-syntax-c1"># graph regularization is done on samples based on their exact number</span> <span class="devsite-syntax-c1"># of neighbors. In other words, non-existent neighbors are discounted.</span> <span class="devsite-syntax-n">feature_spec</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-n">nbr_weight_key</span><span class="devsite-syntax-p">]</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">io</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">FixedLenFeature</span><span class="devsite-syntax-p">(</span> <span class="devsite-syntax-p">[</span><span class="devsite-syntax-mi">1</span><span class="devsite-syntax-p">],</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">float32</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">default_value</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">constant</span><span class="devsite-syntax-p">([</span><span class="devsite-syntax-mf">0.0</span><span class="devsite-syntax-p">]))</span> <span class="devsite-syntax-n">features</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">io</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">parse_single_example</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">example_proto</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">feature_spec</span><span class="devsite-syntax-p">)</span> <span 
class="devsite-syntax-c1"># Since the 'words' feature is a variable length word vector, we pad it to a</span> <span class="devsite-syntax-c1"># constant maximum length based on HPARAMS.max_seq_length</span> <span class="devsite-syntax-n">features</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-s1">'words'</span><span class="devsite-syntax-p">]</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">pad_sequence</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">features</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-s1">'words'</span><span class="devsite-syntax-p">],</span> <span class="devsite-syntax-n">HPARAMS</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">max_seq_length</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-k">if</span> <span class="devsite-syntax-n">training</span><span class="devsite-syntax-p">:</span> <span class="devsite-syntax-k">for</span> <span class="devsite-syntax-n">i</span> <span class="devsite-syntax-ow">in</span> <span class="devsite-syntax-nb">range</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">HPARAMS</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">num_neighbors</span><span class="devsite-syntax-p">):</span> <span class="devsite-syntax-n">nbr_feature_key</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-s1">'</span><span class="devsite-syntax-si">{}{}</span><span class="devsite-syntax-s1">_</span><span class="devsite-syntax-si">{}</span><span class="devsite-syntax-s1">'</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">format</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">NBR_FEATURE_PREFIX</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">i</span><span class="devsite-syntax-p">,</span> <span 
class="devsite-syntax-s1">'words'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">features</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-n">nbr_feature_key</span><span class="devsite-syntax-p">]</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">pad_sequence</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">features</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-n">nbr_feature_key</span><span class="devsite-syntax-p">],</span> <span class="devsite-syntax-n">HPARAMS</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">max_seq_length</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">labels</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">features</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">pop</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s1">'label'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-k">return</span> <span class="devsite-syntax-n">features</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">labels</span> <span class="devsite-syntax-n">dataset</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">data</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">TFRecordDataset</span><span class="devsite-syntax-p">([</span><span class="devsite-syntax-n">file_path</span><span class="devsite-syntax-p">])</span> <span class="devsite-syntax-k">if</span> <span class="devsite-syntax-n">training</span><span class="devsite-syntax-p">:</span> <span class="devsite-syntax-n">dataset</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">dataset</span><span class="devsite-syntax-o">.</span><span 
class="devsite-syntax-n">shuffle</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-mi">10000</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">dataset</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">dataset</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">map</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">parse_example</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">dataset</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">dataset</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">batch</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">HPARAMS</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">batch_size</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-k">return</span> <span class="devsite-syntax-n">dataset</span> <span class="devsite-syntax-n">train_dataset</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">make_dataset</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s1">'/tmp/imdb/nsl_train_data.tfr'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-kc">True</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">test_dataset</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">make_dataset</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s1">'/tmp/imdb/test_data.tfr'</span><span class="devsite-syntax-p">)</span> </code></pre></devsite-code> <h3 id="build_the_model" data-text="Build the model" tabindex="-1">Build the model</h3> <p>A neural network is created by stacking layers—this requires two main architectural decisions:</p> <ul> <li>How many layers to use in the model?</li> <li>How many <em>hidden units</em> to use for each 
layer?</li> </ul> <p>In this example, the input data consists of an array of word-indices. The labels to predict are either 0 or 1.</p> <p>We will use a bi-directional LSTM as our base model in this tutorial.</p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-c1"># This function exists as an alternative to the bi-LSTM model used in this</span> <span class="devsite-syntax-c1"># notebook.</span> <span class="devsite-syntax-k">def</span> <span class="devsite-syntax-nf">make_feed_forward_model</span><span class="devsite-syntax-p">():</span> <span class="devsite-syntax-w"> </span><span class="devsite-syntax-sd">"""Builds a simple 2 layer feed forward neural network."""</span> <span class="devsite-syntax-n">inputs</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">keras</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">Input</span><span class="devsite-syntax-p">(</span> <span class="devsite-syntax-n">shape</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">HPARAMS</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">max_seq_length</span><span class="devsite-syntax-p">,),</span> <span class="devsite-syntax-n">dtype</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'int64'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">name</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'words'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">embedding_layer</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">keras</span><span 
class="devsite-syntax-o">.</span><span class="devsite-syntax-n">layers</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">Embedding</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">HPARAMS</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">vocab_size</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">16</span><span class="devsite-syntax-p">)(</span><span class="devsite-syntax-n">inputs</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">pooling_layer</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">keras</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">layers</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">GlobalAveragePooling1D</span><span class="devsite-syntax-p">()(</span><span class="devsite-syntax-n">embedding_layer</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">dense_layer</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">keras</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">layers</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">Dense</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-mi">16</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">activation</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'relu'</span><span class="devsite-syntax-p">)(</span><span class="devsite-syntax-n">pooling_layer</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">outputs</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span 
class="devsite-syntax-n">keras</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">layers</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">Dense</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-mi">1</span><span class="devsite-syntax-p">)(</span><span class="devsite-syntax-n">dense_layer</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-k">return</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">keras</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">Model</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">inputs</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">inputs</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">outputs</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">outputs</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-k">def</span> <span class="devsite-syntax-nf">make_bilstm_model</span><span class="devsite-syntax-p">():</span> <span class="devsite-syntax-w"> </span><span class="devsite-syntax-sd">"""Builds a bi-directional LSTM model."""</span> <span class="devsite-syntax-n">inputs</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">keras</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">Input</span><span class="devsite-syntax-p">(</span> <span class="devsite-syntax-n">shape</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">HPARAMS</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">max_seq_length</span><span class="devsite-syntax-p">,),</span> <span class="devsite-syntax-n">dtype</span><span 
class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'int64'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">name</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'words'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">embedding_layer</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">keras</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">layers</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">Embedding</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">HPARAMS</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">vocab_size</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">HPARAMS</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">num_embedding_dims</span><span class="devsite-syntax-p">)(</span> <span class="devsite-syntax-n">inputs</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">lstm_layer</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">keras</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">layers</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">Bidirectional</span><span class="devsite-syntax-p">(</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">keras</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">layers</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">LSTM</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">HPARAMS</span><span class="devsite-syntax-o">.</span><span 
class="devsite-syntax-n">num_lstm_dims</span><span class="devsite-syntax-p">))(</span> <span class="devsite-syntax-n">embedding_layer</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">dense_layer</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">keras</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">layers</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">Dense</span><span class="devsite-syntax-p">(</span> <span class="devsite-syntax-n">HPARAMS</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">num_fc_units</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">activation</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'relu'</span><span class="devsite-syntax-p">)(</span> <span class="devsite-syntax-n">lstm_layer</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">outputs</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">keras</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">layers</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">Dense</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-mi">1</span><span class="devsite-syntax-p">)(</span><span class="devsite-syntax-n">dense_layer</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-k">return</span> <span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">keras</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">Model</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">inputs</span><span class="devsite-syntax-o">=</span><span 
class="devsite-syntax-n">inputs</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">outputs</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">outputs</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-c1"># Feel free to use an architecture of your choice.</span> <span class="devsite-syntax-n">model</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">make_bilstm_model</span><span class="devsite-syntax-p">()</span> <span class="devsite-syntax-n">model</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">summary</span><span class="devsite-syntax-p">()</span> </code></pre></devsite-code> <div></div><devsite-code><pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr" is-upgraded> Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= words (InputLayer) [(None, 256)] 0 embedding (Embedding) (None, 256, 16) 160000 bidirectional (Bidirectiona (None, 128) 41472 l) dense (Dense) (None, 64) 8256 dense_1 (Dense) (None, 1) 65 ================================================================= Total params: 209,793 Trainable params: 209,793 Non-trainable params: 0 _________________________________________________________________ </pre></devsite-code> <p>The layers are effectively stacked sequentially to build the classifier:</p> <ol> <li>The first layer is an <code translate="no" dir="ltr">Input</code> layer which takes the integer-encoded vocabulary.</li> <li>The next layer is an <code translate="no" dir="ltr">Embedding</code> layer, which takes the integer-encoded vocabulary and looks up the embedding vector for each word-index. These vectors are learned as the model trains. The vectors add a dimension to the output array. 
The resulting dimensions are: <code translate="no" dir="ltr">(batch, sequence, embedding)</code>.</li> <li>Next, a bidirectional LSTM layer returns a fixed-length output vector for each example.</li> <li>This fixed-length output vector is piped through a fully-connected (<code translate="no" dir="ltr">Dense</code>) layer with 64 hidden units.</li> <li>The last layer is densely connected with a single output node. It has no activation function, so this value is a raw logit; passing it through the <code translate="no" dir="ltr">sigmoid</code> function yields a float between 0 and 1, representing a probability, or confidence level.</li> </ol> <h3 id="hidden_units" data-text="Hidden units" tabindex="-1">Hidden units</h3> <p>The above model has two intermediate or "hidden" layers, between the input and output, and excluding the <code translate="no" dir="ltr">Embedding</code> layer. The number of outputs (units, nodes, or neurons) is the dimension of the representational space for the layer. In other words, the amount of freedom the network is allowed when learning an internal representation.</p> <p>If a model has more hidden units (a higher-dimensional representation space), and/or more layers, then the network can learn more complex representations. However, it makes the network more computationally expensive and may lead to learning unwanted patterns—patterns that improve performance on training data but not on the test data. This is called <em>overfitting</em>.</p> <h3 id="loss_function_and_optimizer" data-text="Loss function and optimizer" tabindex="-1">Loss function and optimizer</h3> <p>A model needs a loss function and an optimizer for training.
Since this is a binary classification problem and the model outputs a single logit (a single-unit layer with no activation), we'll use the <code translate="no" dir="ltr">binary_crossentropy</code> loss function with <code translate="no" dir="ltr">from_logits=True</code>, which applies the sigmoid internally.</p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-n">model</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">compile</span><span class="devsite-syntax-p">(</span> <span class="devsite-syntax-n">optimizer</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'adam'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">loss</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">keras</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">losses</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">BinaryCrossentropy</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">from_logits</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-kc">True</span><span class="devsite-syntax-p">),</span> <span class="devsite-syntax-n">metrics</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-s1">'accuracy'</span><span class="devsite-syntax-p">])</span> </code></pre></devsite-code> <h3 id="create_a_validation_set" data-text="Create a validation set" tabindex="-1">Create a validation set</h3> <p>When training, we want to check the accuracy of the model on data it hasn't seen before. Create a <em>validation set</em> by setting apart a fraction of the original training data. (Why not use the testing set now?
Our goal is to develop and tune our model using only the training data, then use the test data just once to evaluate our accuracy).</p> <p>In this tutorial, we take roughly 10% of the initial training samples (10% of 25000) as labeled data for training and the remaining as validation data. Since the initial train/test split was 50/50 (25000 samples each), the effective train/validation/test split we now have is 5/45/50.</p> <p>Note that 'train_dataset' has already been batched and shuffled. </p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-n">validation_fraction</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-mf">0.9</span> <span class="devsite-syntax-n">validation_size</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-nb">int</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">validation_fraction</span> <span class="devsite-syntax-o">*</span> <span class="devsite-syntax-nb">int</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">training_samples_count</span> <span class="devsite-syntax-o">/</span> <span class="devsite-syntax-n">HPARAMS</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">batch_size</span><span class="devsite-syntax-p">))</span> <span class="devsite-syntax-nb">print</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">validation_size</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">validation_dataset</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">train_dataset</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">take</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">validation_size</span><span class="devsite-syntax-p">)</span> <span 
class="devsite-syntax-n">train_dataset</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">train_dataset</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">skip</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">validation_size</span><span class="devsite-syntax-p">)</span> </code></pre></devsite-code> <div></div><devsite-code><pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr" is-upgraded> 175 </pre></devsite-code> <h3 id="train_the_model" data-text="Train the model" tabindex="-1">Train the model</h3> <p>Train the model in mini-batches. While training, monitor the model's loss and accuracy on the validation set:</p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-n">history</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">model</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">fit</span><span class="devsite-syntax-p">(</span> <span class="devsite-syntax-n">train_dataset</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">validation_data</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">validation_dataset</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">epochs</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">HPARAMS</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">train_epochs</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">verbose</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-mi">1</span><span class="devsite-syntax-p">)</span> </code></pre></devsite-code> <div></div><devsite-code><pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr" is-upgraded> Epoch 1/10 
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/engine/functional.py:638: UserWarning: Input dict contained keys ['NL_nbr_0_words', 'NL_nbr_1_words', 'NL_nbr_0_weight', 'NL_nbr_1_weight'] which did not match any model input. They will be ignored by the model. inputs = self._flatten_to_reference_inputs(inputs) 21/21 [==============================] - 20s 790ms/step - loss: 0.6928 - accuracy: 0.4850 - val_loss: 0.6927 - val_accuracy: 0.5001 Epoch 2/10 21/21 [==============================] - 15s 739ms/step - loss: 0.6847 - accuracy: 0.5019 - val_loss: 0.6387 - val_accuracy: 0.5028 Epoch 3/10 21/21 [==============================] - 15s 741ms/step - loss: 0.6641 - accuracy: 0.5350 - val_loss: 0.6572 - val_accuracy: 0.5002 Epoch 4/10 21/21 [==============================] - 15s 740ms/step - loss: 0.6083 - accuracy: 0.5504 - val_loss: 0.5291 - val_accuracy: 0.7685 Epoch 5/10 21/21 [==============================] - 15s 742ms/step - loss: 0.4911 - accuracy: 0.7635 - val_loss: 0.4327 - val_accuracy: 0.8143 Epoch 6/10 21/21 [==============================] - 15s 741ms/step - loss: 0.3924 - accuracy: 0.8304 - val_loss: 0.3821 - val_accuracy: 0.8529 Epoch 7/10 21/21 [==============================] - 15s 746ms/step - loss: 0.3449 - accuracy: 0.8612 - val_loss: 0.3550 - val_accuracy: 0.8145 Epoch 8/10 21/21 [==============================] - 16s 753ms/step - loss: 0.2954 - accuracy: 0.8796 - val_loss: 0.3103 - val_accuracy: 0.8671 Epoch 9/10 21/21 [==============================] - 16s 767ms/step - loss: 0.3243 - accuracy: 0.8719 - val_loss: 0.3371 - val_accuracy: 0.8733 Epoch 10/10 21/21 [==============================] - 16s 768ms/step - loss: 0.2918 - accuracy: 0.8765 - val_loss: 0.2845 - val_accuracy: 0.8944 </pre></devsite-code> <h3 id="evaluate_the_model" data-text="Evaluate the model" tabindex="-1">Evaluate the model</h3> <p>Now, let's see how the model performs. Two values will be returned. 
Loss (a number which represents our error, lower values are better), and accuracy.</p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-n">results</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">model</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">evaluate</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">test_dataset</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">steps</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">HPARAMS</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">eval_steps</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-nb">print</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">results</span><span class="devsite-syntax-p">)</span> </code></pre></devsite-code> <div></div><devsite-code><pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr" is-upgraded> 196/196 [==============================] - 14s 69ms/step - loss: 0.3740 - accuracy: 0.8502 [0.37399888038635254, 0.8502399921417236] </pre></devsite-code> <h3 id="create_a_graph_of_accuracyloss_over_time" data-text="Create a graph of accuracy/loss over time" tabindex="-1">Create a graph of accuracy/loss over time</h3> <p><code translate="no" dir="ltr">model.fit()</code> returns a <code translate="no" dir="ltr">History</code> object that contains a dictionary with everything that happened during training:</p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-n">history_dict</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">history</span><span class="devsite-syntax-o">.</span><span 
class="devsite-syntax-n">history</span> <span class="devsite-syntax-n">history_dict</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">keys</span><span class="devsite-syntax-p">()</span> </code></pre></devsite-code> <div></div><devsite-code><pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr" is-upgraded> dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy']) </pre></devsite-code> <p>There are four entries: one for each monitored metric during training and validation. We can use these to plot the training and validation loss for comparison, as well as the training and validation accuracy:</p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-n">acc</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">history_dict</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-s1">'accuracy'</span><span class="devsite-syntax-p">]</span> <span class="devsite-syntax-n">val_acc</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">history_dict</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-s1">'val_accuracy'</span><span class="devsite-syntax-p">]</span> <span class="devsite-syntax-n">loss</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">history_dict</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-s1">'loss'</span><span class="devsite-syntax-p">]</span> <span class="devsite-syntax-n">val_loss</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">history_dict</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-s1">'val_loss'</span><span class="devsite-syntax-p">]</span> <span class="devsite-syntax-n">epochs</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-nb">range</span><span 
class="devsite-syntax-p">(</span><span class="devsite-syntax-mi">1</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-nb">len</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">acc</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-o">+</span> <span class="devsite-syntax-mi">1</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-c1"># "-r^" is for solid red line with triangle markers.</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">plot</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">epochs</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">loss</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-s1">'-r^'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">label</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'Training loss'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-c1"># "-bo" is for solid blue line with circle markers.</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">plot</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">epochs</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">val_loss</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-s1">'-bo'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">label</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'Validation loss'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">title</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s1">'Training and validation loss'</span><span
class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">xlabel</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s1">'Epochs'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">ylabel</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s1">'Loss'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">legend</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">loc</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'best'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">show</span><span class="devsite-syntax-p">()</span> </code></pre></devsite-code> <p><img src="/static/neural_structured_learning/tutorials/graph_keras_lstm_imdb_files/output_nGoYf2Js-lle_0.png" alt="Line plot of training and validation loss per epoch"></p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">clf</span><span class="devsite-syntax-p">()</span> <span class="devsite-syntax-c1"># clear figure</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">plot</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">epochs</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">acc</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-s1">'-r^'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">label</span><span
class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'Training acc'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">plot</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">epochs</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">val_acc</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-s1">'-bo'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">label</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'Validation acc'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">title</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s1">'Training and validation accuracy'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">xlabel</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s1">'Epochs'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">ylabel</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s1">'Accuracy'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">legend</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">loc</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'best'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">show</span><span class="devsite-syntax-p">()</span> 
</code></pre></devsite-code> <p><img src="/static/neural_structured_learning/tutorials/graph_keras_lstm_imdb_files/output_6hXx-xOv-llh_0.png" alt="png"></p> <p>Notice the training loss <em>decreases</em> with each epoch and the training accuracy <em>increases</em> with each epoch. This is expected when using gradient descent optimization—it should minimize the desired quantity on every iteration.</p> <h2 id="graph_regularization" data-text="Graph regularization" tabindex="-1">Graph regularization</h2> <p>We are now ready to try graph regularization using the base model that we built above. We will use the <code translate="no" dir="ltr">GraphRegularization</code> wrapper class provided by the Neural Structured Learning framework to wrap the base (bi-LSTM) model to include graph regularization. The rest of the steps for training and evaluating the graph-regularized model are similar to those for the base model.</p> <h3 id="create_graph-regularized_model" data-text="Create graph-regularized model" tabindex="-1">Create graph-regularized model</h3> <p>To assess the incremental benefit of graph regularization, we will create a new base model instance. 
This is because <code translate="no" dir="ltr">model</code> has already been trained for a few iterations, and reusing this trained model to create a graph-regularized model will not be a fair comparison for <code translate="no" dir="ltr">model</code>.</p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-c1"># Build a new base LSTM model.</span> <span class="devsite-syntax-n">base_reg_model</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">make_bilstm_model</span><span class="devsite-syntax-p">()</span> </code></pre></devsite-code><div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-c1"># Wrap the base model with graph regularization.</span> <span class="devsite-syntax-n">graph_reg_config</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">nsl</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">configs</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">make_graph_reg_config</span><span class="devsite-syntax-p">(</span> <span class="devsite-syntax-n">max_neighbors</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">HPARAMS</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">num_neighbors</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">multiplier</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">HPARAMS</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">graph_regularization_multiplier</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">distance_type</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">HPARAMS</span><span 
class="devsite-syntax-o">.</span><span class="devsite-syntax-n">distance_type</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">sum_over_axis</span><span class="devsite-syntax-o">=-</span><span class="devsite-syntax-mi">1</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">graph_reg_model</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">nsl</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">keras</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">GraphRegularization</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">base_reg_model</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">graph_reg_config</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">graph_reg_model</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">compile</span><span class="devsite-syntax-p">(</span> <span class="devsite-syntax-n">optimizer</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'adam'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">loss</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">tf</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">keras</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">losses</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">BinaryCrossentropy</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">from_logits</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-kc">True</span><span class="devsite-syntax-p">),</span> <span class="devsite-syntax-n">metrics</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-s1">'accuracy'</span><span class="devsite-syntax-p">])</span> 
</code></pre></devsite-code> <h3 id="train_the_model_2" data-text="Train the model" tabindex="-1">Train the model</h3> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-n">graph_reg_history</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">graph_reg_model</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">fit</span><span class="devsite-syntax-p">(</span> <span class="devsite-syntax-n">train_dataset</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">validation_data</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">validation_dataset</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">epochs</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">HPARAMS</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">train_epochs</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">verbose</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-mi">1</span><span class="devsite-syntax-p">)</span> </code></pre></devsite-code> <div></div><devsite-code><pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr" is-upgraded> Epoch 1/10 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23. Instructions for updating: Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. 
https://github.com/tensorflow/tensorflow/issues/56089 WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/python/autograph/pyct/static_analysis/liveness.py:83: Analyzer.lamba_check (from tensorflow.python.autograph.pyct.static_analysis.liveness) is deprecated and will be removed after 2023-09-23. Instructions for updating: Lambda fuctions will be no more assumed to be used in the statement where they are used, or at least in the same block. https://github.com/tensorflow/tensorflow/issues/56089 21/21 [==============================] - 27s 920ms/step - loss: 0.6938 - accuracy: 0.4858 - scaled_graph_loss: 3.3994e-05 - val_loss: 0.6928 - val_accuracy: 0.5024 Epoch 2/10 21/21 [==============================] - 17s 836ms/step - loss: 0.6921 - accuracy: 0.5085 - scaled_graph_loss: 2.2528e-05 - val_loss: 0.6916 - val_accuracy: 0.4987 Epoch 3/10 21/21 [==============================] - 18s 844ms/step - loss: 0.6806 - accuracy: 0.5088 - scaled_graph_loss: 0.0018 - val_loss: 0.6383 - val_accuracy: 0.6404 Epoch 4/10 21/21 [==============================] - 17s 837ms/step - loss: 0.6143 - accuracy: 0.6588 - scaled_graph_loss: 0.0292 - val_loss: 0.5993 - val_accuracy: 0.5436 Epoch 5/10 21/21 [==============================] - 17s 841ms/step - loss: 0.5748 - accuracy: 0.7015 - scaled_graph_loss: 0.0563 - val_loss: 0.4726 - val_accuracy: 0.8239 Epoch 6/10 21/21 [==============================] - 18s 847ms/step - loss: 0.5366 - accuracy: 0.8019 - scaled_graph_loss: 0.0681 - val_loss: 0.4708 - val_accuracy: 0.7508 Epoch 7/10 21/21 [==============================] - 18s 847ms/step - loss: 0.5330 - accuracy: 0.7992 - scaled_graph_loss: 0.0722 - val_loss: 0.4462 - val_accuracy: 0.8373 Epoch 8/10 21/21 [==============================] - 18s 848ms/step - loss: 0.5207 - accuracy: 0.8096 - scaled_graph_loss: 0.0755 - val_loss: 0.4772 - val_accuracy: 0.7738 Epoch 9/10 21/21 [==============================] - 18s 851ms/step - loss: 0.5139 - accuracy: 0.8319 
- scaled_graph_loss: 0.0831 - val_loss: 0.4223 - val_accuracy: 0.8412 Epoch 10/10 21/21 [==============================] - 18s 851ms/step - loss: 0.4959 - accuracy: 0.8377 - scaled_graph_loss: 0.0813 - val_loss: 0.4332 - val_accuracy: 0.8199 </pre></devsite-code> <h3 id="evaluate_the_model_2" data-text="Evaluate the model" tabindex="-1">Evaluate the model</h3> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-n">graph_reg_results</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">graph_reg_model</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">evaluate</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">test_dataset</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">steps</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-n">HPARAMS</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">eval_steps</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-nb">print</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">graph_reg_results</span><span class="devsite-syntax-p">)</span> </code></pre></devsite-code> <div></div><devsite-code><pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr" is-upgraded> 196/196 [==============================] - 15s 70ms/step - loss: 0.4728 - accuracy: 0.7732 [0.4728052020072937, 0.7731599807739258] </pre></devsite-code> <h3 id="create_a_graph_of_accuracyloss_over_time_2" data-text="Create a graph of accuracy/loss over time" tabindex="-1">Create a graph of accuracy/loss over time</h3> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-n">graph_reg_history_dict</span> <span 
class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">graph_reg_history</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">history</span> <span class="devsite-syntax-n">graph_reg_history_dict</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">keys</span><span class="devsite-syntax-p">()</span> </code></pre></devsite-code> <div></div><devsite-code><pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr" is-upgraded> dict_keys(['loss', 'accuracy', 'scaled_graph_loss', 'val_loss', 'val_accuracy']) </pre></devsite-code> <p>There are five entries in total in the dictionary: training loss, training accuracy, training graph loss, validation loss, and validation accuracy. We can plot them all together for comparison. Note that the graph loss is only computed during training.</p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-n">acc</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">graph_reg_history_dict</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-s1">'accuracy'</span><span class="devsite-syntax-p">]</span> <span class="devsite-syntax-n">val_acc</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">graph_reg_history_dict</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-s1">'val_accuracy'</span><span class="devsite-syntax-p">]</span> <span class="devsite-syntax-n">loss</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">graph_reg_history_dict</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-s1">'loss'</span><span class="devsite-syntax-p">]</span> <span class="devsite-syntax-n">graph_loss</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">graph_reg_history_dict</span><span 
class="devsite-syntax-p">[</span><span class="devsite-syntax-s1">'scaled_graph_loss'</span><span class="devsite-syntax-p">]</span> <span class="devsite-syntax-n">val_loss</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">graph_reg_history_dict</span><span class="devsite-syntax-p">[</span><span class="devsite-syntax-s1">'val_loss'</span><span class="devsite-syntax-p">]</span> <span class="devsite-syntax-n">epochs</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-nb">range</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-mi">1</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-nb">len</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">acc</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-o">+</span> <span class="devsite-syntax-mi">1</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">clf</span><span class="devsite-syntax-p">()</span> <span class="devsite-syntax-c1"># clear figure</span> <span class="devsite-syntax-c1"># "-r^" is for solid red line with triangle markers.</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">plot</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">epochs</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">loss</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-s1">'-r^'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">label</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'Training loss'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-c1"># "-gD" is for solid green line with diamond markers.</span> <span class="devsite-syntax-n">plt</span><span 
class="devsite-syntax-o">.</span><span class="devsite-syntax-n">plot</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">epochs</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">graph_loss</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-s1">'-gD'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">label</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'Training graph loss'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-c1"># "-bo" is for solid blue line with circle markers.</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">plot</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">epochs</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">val_loss</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-s1">'-bo'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">label</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'Validation loss'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">title</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s1">'Training and validation loss'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">xlabel</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s1">'Epochs'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">ylabel</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s1">'Loss'</span><span 
class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">legend</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">loc</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'best'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">show</span><span class="devsite-syntax-p">()</span> </code></pre></devsite-code> <p><img src="/static/neural_structured_learning/tutorials/graph_keras_lstm_imdb_files/output_YhjhH4n_aprb_0.png" alt="png"></p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">clf</span><span class="devsite-syntax-p">()</span> <span class="devsite-syntax-c1"># clear figure</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">plot</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">epochs</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">acc</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-s1">'-r^'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">label</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'Training acc'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">plot</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">epochs</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">val_acc</span><span class="devsite-syntax-p">,</span> <span 
class="devsite-syntax-s1">'-bo'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">label</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'Validation acc'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">title</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s1">'Training and validation accuracy'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">xlabel</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s1">'Epochs'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">ylabel</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s1">'Accuracy'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">legend</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">loc</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'best'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">show</span><span class="devsite-syntax-p">()</span> </code></pre></devsite-code> <p><img src="/static/neural_structured_learning/tutorials/graph_keras_lstm_imdb_files/output_NE0vcDiqa1Id_0.png" alt="png"></p> <h2 id="the_power_of_semi-supervised_learning" data-text="The power of semi-supervised learning" tabindex="-1">The power of semi-supervised learning</h2> <p>Semi-supervised learning and more specifically, graph regularization in the context of this tutorial, can be really powerful when the amount of training data is 
small. The lack of training data is compensated by leveraging similarity among the training samples, which is not possible in traditional supervised learning.</p> <p>We define <strong><em>supervision ratio</em></strong> as the ratio of training samples to the total number of samples, which includes training, validation, and test samples. In this notebook, we have used a supervision ratio of 0.05 (i.e., 5% of the labeled data) for training both the base model as well as the graph-regularized model. We illustrate the impact of the supervision ratio on model accuracy in the cell below.</p> <div></div><devsite-code><pre class="devsite-click-to-copy" translate="no" dir="ltr" is-upgraded syntax="Python"><code translate="no" dir="ltr"><span class="devsite-syntax-c1"># Accuracy values for both the Bi-LSTM model and the feed forward NN model have</span> <span class="devsite-syntax-c1"># been precomputed for the following supervision ratios.</span> <span class="devsite-syntax-n">supervision_ratios</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-p">[</span><span class="devsite-syntax-mf">0.3</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mf">0.15</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mf">0.05</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mf">0.03</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mf">0.02</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mf">0.01</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mf">0.005</span><span class="devsite-syntax-p">]</span> <span class="devsite-syntax-n">model_tags</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-p">[</span><span class="devsite-syntax-s1">'Bi-LSTM model'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-s1">'Feed Forward NN model'</span><span class="devsite-syntax-p">]</span> <span 
class="devsite-syntax-n">base_model_accs</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-p">[[</span><span class="devsite-syntax-mi">84</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">84</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">83</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">80</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">65</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">52</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">50</span><span class="devsite-syntax-p">],</span> <span class="devsite-syntax-p">[</span><span class="devsite-syntax-mi">87</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">86</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">76</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">74</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">67</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">52</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">51</span><span class="devsite-syntax-p">]]</span> <span class="devsite-syntax-n">graph_reg_model_accs</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-p">[[</span><span class="devsite-syntax-mi">84</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">84</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">83</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">83</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">65</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">63</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">50</span><span 
class="devsite-syntax-p">],</span> <span class="devsite-syntax-p">[</span><span class="devsite-syntax-mi">87</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">86</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">80</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">75</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">67</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">52</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">50</span><span class="devsite-syntax-p">]]</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">clf</span><span class="devsite-syntax-p">()</span> <span class="devsite-syntax-c1"># clear figure</span> <span class="devsite-syntax-n">fig</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">axes</span> <span class="devsite-syntax-o">=</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">subplots</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-mi">1</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">2</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">fig</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">set_size_inches</span><span class="devsite-syntax-p">((</span><span class="devsite-syntax-mi">12</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">5</span><span class="devsite-syntax-p">))</span> <span class="devsite-syntax-k">for</span> <span class="devsite-syntax-n">ax</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">model_tag</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">base_model_acc</span><span class="devsite-syntax-p">,</span> <span 
class="devsite-syntax-n">graph_reg_model_acc</span> <span class="devsite-syntax-ow">in</span> <span class="devsite-syntax-nb">zip</span><span class="devsite-syntax-p">(</span> <span class="devsite-syntax-n">axes</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">model_tags</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">base_model_accs</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">graph_reg_model_accs</span><span class="devsite-syntax-p">):</span> <span class="devsite-syntax-c1"># "-r^" is for solid red line with triangle markers.</span> <span class="devsite-syntax-n">ax</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">plot</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">base_model_acc</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-s1">'-r^'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">label</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'Base model'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-c1"># "-gD" is for solid green line with diamond markers.</span> <span class="devsite-syntax-n">ax</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">plot</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">graph_reg_model_acc</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-s1">'-gD'</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-n">label</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'Graph-regularized model'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">ax</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">set_title</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">model_tag</span><span 
class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">ax</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">set_xlabel</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s1">'Supervision ratio'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">ax</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">set_ylabel</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-s1">'Accuracy(%)'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">ax</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">set_ylim</span><span class="devsite-syntax-p">((</span><span class="devsite-syntax-mi">25</span><span class="devsite-syntax-p">,</span> <span class="devsite-syntax-mi">100</span><span class="devsite-syntax-p">))</span> <span class="devsite-syntax-n">ax</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">set_xticks</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-nb">range</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-nb">len</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">supervision_ratios</span><span class="devsite-syntax-p">)))</span> <span class="devsite-syntax-n">ax</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">set_xticklabels</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">supervision_ratios</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">ax</span><span class="devsite-syntax-o">.</span><span class="devsite-syntax-n">legend</span><span class="devsite-syntax-p">(</span><span class="devsite-syntax-n">loc</span><span class="devsite-syntax-o">=</span><span class="devsite-syntax-s1">'best'</span><span class="devsite-syntax-p">)</span> <span class="devsite-syntax-n">plt</span><span class="devsite-syntax-o">.</span><span 
class="devsite-syntax-n">show</span><span class="devsite-syntax-p">()</span> </code></pre></devsite-code> <div></div><devsite-code><pre class="tfo-notebook-code-cell-output" translate="no" dir="ltr" is-upgraded> &lt;Figure size 640x480 with 0 Axes&gt; </pre></devsite-code> <p><img src="/static/neural_structured_learning/tutorials/graph_keras_lstm_imdb_files/output_nWWa384R5vSm_1.png" alt="png"></p> <p>It can be observed that as the supervision ratio decreases, model accuracy also decreases. This is true for both the base model and for the graph-regularized model, regardless of the model architecture used. However, notice that the graph-regularized model performs better than the base model for both the architectures. In particular, for the Bi-LSTM model, when the supervision ratio is 0.01, the accuracy of the graph-regularized model is <strong>~20%</strong> higher than that of the base model. This is primarily because of semi-supervised learning for the graph-regularized model, where structural similarity among training samples is used in addition to the training samples themselves.</p> <h2 id="conclusion" data-text="Conclusion" tabindex="-1">Conclusion</h2> <p>We have demonstrated the use of graph regularization using the Neural Structured Learning (NSL) framework even when the input does not contain an explicit graph. We considered the task of sentiment classification of IMDB movie reviews for which we synthesized a similarity graph based on review embeddings. 
We encourage users to experiment further by varying hyperparameters, the amount of supervision, and by using different model architectures.</p> </div> <devsite-thumb-rating position="footer"> </devsite-thumb-rating> <div class="devsite-floating-action-buttons"> </div> </article> <devsite-content-footer class="nocontent"> <p>Except as otherwise noted, the content of this page is licensed under the <a href="https://creativecommons.org/licenses/by/4.0/">Creative Commons Attribution 4.0 License</a>, and code samples are licensed under the <a href="https://www.apache.org/licenses/LICENSE-2.0">Apache 2.0 License</a>. For details, see the <a href="https://developers.google.com/site-policies">Google Developers Site Policies</a>. Java is a registered trademark of Oracle and/or its affiliates.</p> <p>Last updated 2023-05-27 UTC.</p> </devsite-content-footer> <devsite-notification > </devsite-notification> <div class="devsite-content-data"> <template class="devsite-content-data-template"> [[["Easy to understand","easyToUnderstand","thumb-up"],["Solved my problem","solvedMyProblem","thumb-up"],["Other","otherUp","thumb-up"]],[["Missing the information I need","missingTheInformationINeed","thumb-down"],["Too complicated / too many steps","tooComplicatedTooManySteps","thumb-down"],["Out of date","outOfDate","thumb-down"],["Samples / code issue","samplesCodeIssue","thumb-down"],["Other","otherDown","thumb-down"]],["Last updated 2023-05-27 UTC."],[],[]] </template> </div> </devsite-content> </main> <devsite-footer-promos class="devsite-footer"> </devsite-footer-promos> <devsite-footer-linkboxes class="devsite-footer"> <nav class="devsite-footer-linkboxes nocontent" aria-label="Footer links"> <ul class="devsite-footer-linkboxes-list"> <li class="devsite-footer-linkbox "> <h3 class="devsite-footer-linkbox-heading no-link">Stay connected</h3> <ul class="devsite-footer-linkbox-list"> <li class="devsite-footer-linkbox-item"> <a href="//blog.tensorflow.org" 
class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 1)" > Blog </a> </li> <li class="devsite-footer-linkbox-item"> <a href="//discuss.tensorflow.org" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 2)" > Forum </a> </li> <li class="devsite-footer-linkbox-item"> <a href="//github.com/tensorflow/" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 3)" > GitHub </a> </li> <li class="devsite-footer-linkbox-item"> <a href="//twitter.com/tensorflow" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 4)" > Twitter </a> </li> <li class="devsite-footer-linkbox-item"> <a href="//youtube.com/tensorflow" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 5)" > YouTube </a> </li> </ul> </li> <li class="devsite-footer-linkbox "> <h3 class="devsite-footer-linkbox-heading no-link">Support</h3> <ul class="devsite-footer-linkbox-list"> <li class="devsite-footer-linkbox-item"> <a href="//github.com/tensorflow/tensorflow/issues" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 1)" > Issue tracker </a> </li> <li class="devsite-footer-linkbox-item"> <a href="//github.com/tensorflow/tensorflow/blob/master/RELEASE.md" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 2)" > Release notes </a> </li> <li class="devsite-footer-linkbox-item"> <a href="//stackoverflow.com/questions/tagged/tensorflow" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 3)" > Stack Overflow </a> </li> <li 
class="devsite-footer-linkbox-item"> <a href="/extras/tensorflow_brand_guidelines.pdf" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 4)" > Brand guidelines </a> </li> <li class="devsite-footer-linkbox-item"> <a href="/about/bib" class="devsite-footer-linkbox-link gc-analytics-event" data-category="Site-Wide Custom Events" data-label="Footer Link (index 5)" > Cite TensorFlow </a> </li> </ul> </li> </ul> </nav> </devsite-footer-linkboxes> <devsite-footer-utility class="devsite-footer"> <div class="devsite-footer-utility nocontent"> <nav class="devsite-footer-utility-links" aria-label="Utility links"> <ul class="devsite-footer-utility-list"> <li class="devsite-footer-utility-item "> <a class="devsite-footer-utility-link gc-analytics-event" href="//policies.google.com/terms" data-category="Site-Wide Custom Events" data-label="Footer Terms link" > Terms </a> </li> <li class="devsite-footer-utility-item "> <a class="devsite-footer-utility-link gc-analytics-event" href="//policies.google.com/privacy" data-category="Site-Wide Custom Events" data-label="Footer Privacy link" > Privacy </a> </li> <li class="devsite-footer-utility-item glue-cookie-notification-bar-control"> <a class="devsite-footer-utility-link gc-analytics-event" href="#" data-category="Site-Wide Custom Events" data-label="Footer Manage cookies link" aria-hidden="true" > Manage cookies </a> </li> <li class="devsite-footer-utility-item devsite-footer-utility-button"> <span class="devsite-footer-utility-description">Sign up for the TensorFlow newsletter</span> <a class="devsite-footer-utility-link gc-analytics-event" href="//www.tensorflow.org/subscribe" data-category="Site-Wide Custom Events" data-label="Footer Subscribe link" > Subscribe </a> </li> </ul> <devsite-language-selector> <ul role="presentation"> <li role="presentation"> <a role="menuitem" lang="en" >English</a> </li> <li role="presentation"> <a role="menuitem" 
lang="es_419" >Español – América Latina</a> </li> <li role="presentation"> <a role="menuitem" lang="fr" >Français</a> </li> <li role="presentation"> <a role="menuitem" lang="id" >Indonesia</a> </li> <li role="presentation"> <a role="menuitem" lang="it" >Italiano</a> </li> <li role="presentation"> <a role="menuitem" lang="pl" >Polski</a> </li> <li role="presentation"> <a role="menuitem" lang="pt_br" >Português – Brasil</a> </li> <li role="presentation"> <a role="menuitem" lang="vi" >Tiếng Việt</a> </li> <li role="presentation"> <a role="menuitem" lang="tr" >Türkçe</a> </li> <li role="presentation"> <a role="menuitem" lang="ru" >Русский</a> </li> <li role="presentation"> <a role="menuitem" lang="he" >עברית</a> </li> <li role="presentation"> <a role="menuitem" lang="ar" >العربيّة</a> </li> <li role="presentation"> <a role="menuitem" lang="fa" >فارسی</a> </li> <li role="presentation"> <a role="menuitem" lang="hi" >हिंदी</a> </li> <li role="presentation"> <a role="menuitem" lang="bn" >বাংলা</a> </li> <li role="presentation"> <a role="menuitem" lang="th" >ภาษาไทย</a> </li> <li role="presentation"> <a role="menuitem" lang="zh_cn" >中文 – 简体</a> </li> <li role="presentation"> <a role="menuitem" lang="ja" >日本語</a> </li> <li role="presentation"> <a role="menuitem" lang="ko" >한국어</a> </li> </ul> </devsite-language-selector> </nav> </div> </devsite-footer-utility> <devsite-panel></devsite-panel> </section></section> <devsite-sitemask></devsite-sitemask> <devsite-snackbar></devsite-snackbar> <devsite-tooltip ></devsite-tooltip> <devsite-heading-link></devsite-heading-link> <devsite-analytics> <script type="application/json" analytics>[{"dimensions": {"dimension4": "Neural Structured Learning", "dimension5": "en", "dimension1": "Signed out", "dimension12": false, "dimension6": "en", "dimension3": false}, "gaid": "UA-69864048-1", "metrics": {"ratings_count": "metric2", "ratings_value": "metric1"}, "purpose": 0}]</script> <script type="application/json" tag-management>{"at": "True", 
"ga4": [], "ga4p": [], "gtm": [{"id": "GTM-MXSL34P", "purpose": 0}], "parameters": {"internalUser": "False", "language": {"machineTranslated": "False", "requested": "en", "served": "en"}, "pageType": "article", "projectName": "Neural Structured Learning", "signedIn": "False", "tenant": "tensorflow", "recommendations": {"sourcePage": "", "sourceType": 0, "sourceRank": 0, "sourceIdenticalDescriptions": 0, "sourceTitleWords": 0, "sourceDescriptionWords": 0, "experiment": ""}, "experiment": {"ids": ""}}}</script> </devsite-analytics> <devsite-badger></devsite-badger> <script nonce="zHNFUF1GxjiSkniBNYPTjnCkz3qpkI"> (function(d,e,v,s,i,t,E){d['GoogleDevelopersObject']=i; t=e.createElement(v);t.async=1;t.src=s;E=e.getElementsByTagName(v)[0]; E.parentNode.insertBefore(t,E);})(window, document, 'script', 'https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/js/app_loader.js', '[15,"en",null,"/js/devsite_app_module.js","https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625","https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow","https://tensorflow-dot-devsite-v2-prod-3p.appspot.com",null,null,["/_pwa/tensorflow/manifest.json","https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/images/video-placeholder.svg","https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/images/favicon.png","https://www.gstatic.com/devrel-devsite/prod/v870e399c64f7c43c99a3043db4b3a74327bb93d0914e84a0c3dba90bbfd67625/tensorflow/images/lockup.svg","https://fonts.googleapis.com/css?family=Google+Sans:400,500|Roboto:400,400italic,500,500italic,700,700italic|Roboto+Mono:400,500,700&display=swap"],1,null,[1,6,8,12,14,17,21,25,50,52,63,70,75,76,80,87,91,92,93,97,98,100,101,102,103,104,105,107,108,1
09,110,112,113,117,118,120,122,124,125,126,127,129,130,131,132,133,134,135,136,138,140,141,147,148,149,151,152,156,157,158,159,161,163,164,168,169,170,179,180,182,183,186,191,193,196],"AIzaSyCNm9YxQumEXwGJgTDjxoxXK6m1F-9720Q","AIzaSyCc76DZePGtoyUjqKrLdsMGk_ry7sljLbY","www.tensorflow.org","AIzaSyB9bqgQ2t11WJsOX8qNsCQ6U-w91mmqF-I","AIzaSyAdYnStPdzjcJJtQ0mvIaeaMKj7_t6J_Fg",null,null,null,["MiscFeatureFlags__developers_footer_dark_image","Profiles__enable_release_notes_notifications","Profiles__enable_awarding_url","Profiles__enable_profile_collections","DevPro__enable_developer_subscriptions","Cloud__enable_cloud_shell","Profiles__enable_public_developer_profiles","Analytics__enable_clearcut_logging","Cloud__enable_llm_concierge_chat","SignIn__enable_refresh_access_tokens","Cloud__enable_legacy_calculator_redirect","Profiles__enable_completecodelab_endpoint","Profiles__require_profile_eligibility_for_signin","TpcFeatures__enable_required_headers","Profiles__enable_complete_playlist_endpoint","MiscFeatureFlags__emergency_css","Cloud__enable_free_trial_server_call","EngEduTelemetry__enable_engedu_telemetry","BookNav__enable_tenant_cache_key","MiscFeatureFlags__developers_footer_image","Cloud__enable_cloudx_experiment_ids","MiscFeatureFlags__enable_firebase_utm","TpcFeatures__enable_mirror_tenant_redirects","Profiles__enable_dashboard_curated_recommendations","Concierge__enable_pushui","MiscFeatureFlags__enable_variable_operator","CloudShell__cloud_shell_button","Profiles__enable_recognition_badges","Cloud__enable_cloud_facet_chat","Search__enable_ai_eligibility_checks","Profiles__enable_developer_profiles_callout","Experiments__reqs_query_experiments","DevPro__enable_cloud_innovators_plus","Search__enable_page_map","MiscFeatureFlags__enable_view_transitions","Cloud__enable_cloudx_ping","Cloud__enable_cloud_dlp_service","MiscFeatureFlags__enable_explain_this_code","Cloud__enable_cloud_shell_fte_user_flow","MiscFeatureFlags__enable_project_variables","Profiles__enable_page
_saving","CloudShell__cloud_code_overflow_menu","Search__enable_suggestions_from_borg","Search__enable_dynamic_content_confidential_banner"],null,null,"AIzaSyA58TaKli1DculwmAmbpzLVGuWc8eCQgQc","https://developerscontentserving-pa.googleapis.com","AIzaSyDWBU60w0P9hEkr29kkksYs8Z7gvZ8u_wc","https://developerscontentsearch-pa.googleapis.com",2,4,null,"https://developerprofiles-pa.googleapis.com",[15,"tensorflow","TensorFlow","www.tensorflow.org",null,"tensorflow-dot-devsite-v2-prod-3p.appspot.com",null,null,[null,1,null,null,null,null,null,null,null,null,null,[1],null,null,null,null,null,null,[1],[1,null,null,[1]],null,null,null,[1,null,1],[1,1,null,1,1]],null,[25,null,null,null,null,null,"/images/lockup.svg","/images/logo.png",null,null,null,1,1,null,null,null,null,null,null,null,null,1,null,null,null,null,[]],[],null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,null,[6,1],null,[[],[1,1]],[[["UA-69864048-1"],["UA-69864048-4"],null,null,["UA-69864048-5"],["GTM-MXSL34P"],null,null,[["UA-69864048-1",1]],null,[["UA-69864048-5",1]],[["GTM-MXSL34P",1]],1],[[3,2],[5,4],[4,3],[1,1],[6,5],[12,8]],[[2,2],[1,1]]],null,4]]') </script> <devsite-a11y-announce></devsite-a11y-announce> </body> </html>