CINXE.COM
Multi-Modal and Multi-Task - AutoKeras
<!doctype html> <html lang="en" class="no-js"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width,initial-scale=1"> <meta name="description" content="Documentation for AutoKeras."> <link rel="canonical" href="http://autokeras.com/tutorial/multi/"> <link rel="prev" href="../text_regression/"> <link rel="next" href="../customized/"> <link rel="icon" href="/img/favicon.png"> <meta name="generator" content="mkdocs-1.5.3, mkdocs-material-9.5.14"> <title>Multi-Modal and Multi-Task - AutoKeras</title> <link rel="stylesheet" href="../../assets/stylesheets/main.10ba22f1.min.css"> <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin> <link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300i,400,400i,700,700i%7CRoboto+Mono:400,400i,700,700i&display=fallback"> <style>:root{--md-text-font:"Roboto";--md-code-font:"Roboto Mono"}</style> <link rel="stylesheet" href="../../stylesheets/extra.css"> <script>__md_scope=new URL("../..",location),__md_hash=e=>[...e].reduce((e,_)=>(e<<5)-e+_.charCodeAt(0),0),__md_get=(e,_=localStorage,t=__md_scope)=>JSON.parse(_.getItem(t.pathname+"."+e)),__md_set=(e,_,t=localStorage,a=__md_scope)=>{try{t.setItem(a.pathname+"."+e,JSON.stringify(_))}catch(e){}}</script> <script id="__analytics">function __md_analytics(){function n(){dataLayer.push(arguments)}window.dataLayer=window.dataLayer||[],n("js",new Date),n("config","G-GTF9QP8DFD"),document.addEventListener("DOMContentLoaded",function(){document.forms.search&&document.forms.search.query.addEventListener("blur",function(){this.value&&n("event","search",{search_term:this.value})}),document$.subscribe(function(){var a=document.forms.feedback;if(void 0!==a)for(var e of a.querySelectorAll("[type=submit]"))e.addEventListener("click",function(e){e.preventDefault();var 
t=document.location.pathname,e=this.getAttribute("data-md-value");n("event","feedback",{page:t,data:e}),a.firstElementChild.disabled=!0;e=a.querySelector(".md-feedback__note [data-md-value='"+e+"']");e&&(e.hidden=!1)}),a.hidden=!1}),location$.subscribe(function(e){n("config","G-GTF9QP8DFD",{page_path:e.pathname})})});var e=document.createElement("script");e.async=!0,e.src="https://www.googletagmanager.com/gtag/js?id=G-GTF9QP8DFD",document.getElementById("__analytics").insertAdjacentElement("afterEnd",e)}</script> <script>"undefined"!=typeof __md_analytics&&__md_analytics()</script> </head> <body dir="ltr"> <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer" autocomplete="off"> <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search" autocomplete="off"> <label class="md-overlay" for="__drawer"></label> <div data-md-component="skip"> <a href="#what-is-multi-modal" class="md-skip"> Skip to content </a> </div> <div data-md-component="announce"> </div> <header class="md-header md-header--shadow" data-md-component="header"> <nav class="md-header__inner md-grid" aria-label="Header"> <a href="../.." 
title="AutoKeras" class="md-header__button md-logo" aria-label="AutoKeras" data-md-component="logo"> <img src="/img/logo_white.svg" alt="logo"> </a> <label class="md-header__button md-icon" for="__drawer"> <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M3 6h18v2H3V6m0 5h18v2H3v-2m0 5h18v2H3v-2Z"/></svg> </label> <div class="md-header__title" data-md-component="header-title"> <div class="md-header__ellipsis"> <div class="md-header__topic"> <span class="md-ellipsis"> AutoKeras </span> </div> <div class="md-header__topic" data-md-component="header-topic"> <span class="md-ellipsis"> Multi-Modal and Multi-Task </span> </div> </div> </div> <script>var media,input,key,value,palette=__md_get("__palette");if(palette&&palette.color){"(prefers-color-scheme)"===palette.color.media&&(media=matchMedia("(prefers-color-scheme: light)"),input=document.querySelector(media.matches?"[data-md-color-media='(prefers-color-scheme: light)']":"[data-md-color-media='(prefers-color-scheme: dark)']"),palette.color.media=input.getAttribute("data-md-color-media"),palette.color.scheme=input.getAttribute("data-md-color-scheme"),palette.color.primary=input.getAttribute("data-md-color-primary"),palette.color.accent=input.getAttribute("data-md-color-accent"));for([key,value]of Object.entries(palette.color))document.body.setAttribute("data-md-color-"+key,value)}</script> <label class="md-header__button md-icon" for="__search"> <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.516 6.516 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5Z"/></svg> </label> <div class="md-search" data-md-component="search" role="dialog"> <label class="md-search__overlay" for="__search"></label> <div class="md-search__inner" role="search"> <form class="md-search__form" name="search"> <input type="text" class="md-search__input" 
name="query" aria-label="Search" placeholder="Search" autocapitalize="off" autocorrect="off" autocomplete="off" spellcheck="false" data-md-component="search-query" required> <label class="md-search__icon md-icon" for="__search"> <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.516 6.516 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5Z"/></svg> <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11h12Z"/></svg> </label> <nav class="md-search__options" aria-label="Search"> <button type="reset" class="md-search__icon md-icon" title="Clear" aria-label="Clear" tabindex="-1"> <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M19 6.41 17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12 19 6.41Z"/></svg> </button> </nav> </form> <div class="md-search__output"> <div class="md-search__scrollwrap" data-md-scrollfix> <div class="md-search-result" data-md-component="search-result"> <div class="md-search-result__meta"> Initializing search </div> <ol class="md-search-result__list" role="presentation"></ol> </div> </div> </div> </div> </div> <div class="md-header__source"> <a href="https://github.com/keras-team/autokeras" title="Go to repository" class="md-source" data-md-component="source"> <div class="md-source__icon md-icon"> <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! 
Font Awesome Free 6.5.1 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2023 Fonticons, Inc.--><path d="M439.55 236.05 244 40.45a28.87 28.87 0 0 0-40.81 0l-40.66 40.63 51.52 51.52c27.06-9.14 52.68 16.77 43.39 43.68l49.66 49.66c34.23-11.8 61.18 31 35.47 56.69-26.49 26.49-70.21-2.87-56-37.34L240.22 199v121.85c25.3 12.54 22.26 41.85 9.08 55a34.34 34.34 0 0 1-48.55 0c-17.57-17.6-11.07-46.91 11.25-56v-123c-20.8-8.51-24.6-30.74-18.64-45L142.57 101 8.45 235.14a28.86 28.86 0 0 0 0 40.81l195.61 195.6a28.86 28.86 0 0 0 40.8 0l194.69-194.69a28.86 28.86 0 0 0 0-40.81z"/></svg> </div> <div class="md-source__repository"> GitHub </div> </a> </div> </nav> </header> <div class="md-container" data-md-component="container"> <main class="md-main" data-md-component="main"> <div class="md-main__inner md-grid"> <div class="md-sidebar md-sidebar--primary" data-md-component="sidebar" data-md-type="navigation" > <div class="md-sidebar__scrollwrap"> <div class="md-sidebar__inner"> <nav class="md-nav md-nav--primary" aria-label="Navigation" data-md-level="0"> <label class="md-nav__title" for="__drawer"> <a href="../.." title="AutoKeras" class="md-nav__button md-logo" aria-label="AutoKeras" data-md-component="logo"> <img src="/img/logo_white.svg" alt="logo"> </a> AutoKeras </label> <div class="md-nav__source"> <a href="https://github.com/keras-team/autokeras" title="Go to repository" class="md-source" data-md-component="source"> <div class="md-source__icon md-icon"> <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! 
Font Awesome Free 6.5.1 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2023 Fonticons, Inc.--><path d="M439.55 236.05 244 40.45a28.87 28.87 0 0 0-40.81 0l-40.66 40.63 51.52 51.52c27.06-9.14 52.68 16.77 43.39 43.68l49.66 49.66c34.23-11.8 61.18 31 35.47 56.69-26.49 26.49-70.21-2.87-56-37.34L240.22 199v121.85c25.3 12.54 22.26 41.85 9.08 55a34.34 34.34 0 0 1-48.55 0c-17.57-17.6-11.07-46.91 11.25-56v-123c-20.8-8.51-24.6-30.74-18.64-45L142.57 101 8.45 235.14a28.86 28.86 0 0 0 0 40.81l195.61 195.6a28.86 28.86 0 0 0 40.8 0l194.69-194.69a28.86 28.86 0 0 0 0-40.81z"/></svg> </div> <div class="md-source__repository"> GitHub </div> </a> </div> <ul class="md-nav__list" data-md-scrollfix> <li class="md-nav__item"> <a href="../.." class="md-nav__link"> <span class="md-ellipsis"> Home </span> </a> </li> <li class="md-nav__item"> <a href="../../install/" class="md-nav__link"> <span class="md-ellipsis"> Installation </span> </a> </li> <li class="md-nav__item md-nav__item--active md-nav__item--nested"> <input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_3" checked> <label class="md-nav__link" for="__nav_3" id="__nav_3_label" tabindex="0"> <span class="md-ellipsis"> Tutorials </span> <span class="md-nav__icon md-icon"></span> </label> <nav class="md-nav" data-md-level="1" aria-labelledby="__nav_3_label" aria-expanded="true"> <label class="md-nav__title" for="__nav_3"> <span class="md-nav__icon md-icon"></span> Tutorials </label> <ul class="md-nav__list" data-md-scrollfix> <li class="md-nav__item"> <a href="../overview/" class="md-nav__link"> <span class="md-ellipsis"> Overview </span> </a> </li> <li class="md-nav__item"> <a href="../image_classification/" class="md-nav__link"> <span class="md-ellipsis"> Image Classification </span> </a> </li> <li class="md-nav__item"> <a href="../image_regression/" class="md-nav__link"> <span class="md-ellipsis"> Image 
Regression </span> </a> </li> <li class="md-nav__item"> <a href="../text_classification/" class="md-nav__link"> <span class="md-ellipsis"> Text Classification </span> </a> </li> <li class="md-nav__item"> <a href="../text_regression/" class="md-nav__link"> <span class="md-ellipsis"> Text Regression </span> </a> </li> <li class="md-nav__item md-nav__item--active"> <input class="md-nav__toggle md-toggle" type="checkbox" id="__toc"> <label class="md-nav__link md-nav__link--active" for="__toc"> <span class="md-ellipsis"> Multi-Modal and Multi-Task </span> <span class="md-nav__icon md-icon"></span> </label> <a href="./" class="md-nav__link md-nav__link--active"> <span class="md-ellipsis"> Multi-Modal and Multi-Task </span> </a> <nav class="md-nav md-nav--secondary" aria-label="Table of contents"> <label class="md-nav__title" for="__toc"> <span class="md-nav__icon md-icon"></span> Table of contents </label> <ul class="md-nav__list" data-md-component="toc" data-md-scrollfix> <li class="md-nav__item"> <a href="#what-is-multi-modal" class="md-nav__link"> <span class="md-ellipsis"> What is multi-modal? </span> </a> </li> <li class="md-nav__item"> <a href="#what-is-multi-task" class="md-nav__link"> <span class="md-ellipsis"> What is multi-task? 
</span> </a> </li> <li class="md-nav__item"> <a href="#data-preparation" class="md-nav__link"> <span class="md-ellipsis"> Data Preparation </span> </a> </li> <li class="md-nav__item"> <a href="#build-and-train-the-model" class="md-nav__link"> <span class="md-ellipsis"> Build and Train the Model </span> </a> </li> <li class="md-nav__item"> <a href="#validation-data" class="md-nav__link"> <span class="md-ellipsis"> Validation Data </span> </a> </li> <li class="md-nav__item"> <a href="#customized-search-space" class="md-nav__link"> <span class="md-ellipsis"> Customized Search Space </span> </a> </li> <li class="md-nav__item"> <a href="#data-format" class="md-nav__link"> <span class="md-ellipsis"> Data Format </span> </a> </li> <li class="md-nav__item"> <a href="#reference" class="md-nav__link"> <span class="md-ellipsis"> Reference </span> </a> </li> </ul> </nav> </li> <li class="md-nav__item"> <a href="../customized/" class="md-nav__link"> <span class="md-ellipsis"> Customized Model </span> </a> </li> <li class="md-nav__item"> <a href="../export/" class="md-nav__link"> <span class="md-ellipsis"> Export Model </span> </a> </li> <li class="md-nav__item"> <a href="../load/" class="md-nav__link"> <span class="md-ellipsis"> Load Data from Disk </span> </a> </li> <li class="md-nav__item"> <a href="../faq/" class="md-nav__link"> <span class="md-ellipsis"> FAQ </span> </a> </li> </ul> </nav> </li> <li class="md-nav__item md-nav__item--nested"> <input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_4" > <label class="md-nav__link" for="__nav_4" id="__nav_4_label" tabindex="0"> <span class="md-ellipsis"> Extensions </span> <span class="md-nav__icon md-icon"></span> </label> <nav class="md-nav" data-md-level="1" aria-labelledby="__nav_4_label" aria-expanded="false"> <label class="md-nav__title" for="__nav_4"> <span class="md-nav__icon md-icon"></span> Extensions </label> <ul class="md-nav__list" data-md-scrollfix> <li class="md-nav__item"> <a 
href="../../extensions/tf_cloud/" class="md-nav__link"> <span class="md-ellipsis"> TensorFlow Cloud </span> </a> </li> <li class="md-nav__item"> <a href="../../extensions/trains/" class="md-nav__link"> <span class="md-ellipsis"> TRAINS </span> </a> </li> </ul> </nav> </li> <li class="md-nav__item"> <a href="../../docker/" class="md-nav__link"> <span class="md-ellipsis"> Docker </span> </a> </li> <li class="md-nav__item"> <a href="../../contributing/" class="md-nav__link"> <span class="md-ellipsis"> Contributing Guide </span> </a> </li> <li class="md-nav__item md-nav__item--nested"> <input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_7" > <label class="md-nav__link" for="__nav_7" id="__nav_7_label" tabindex="0"> <span class="md-ellipsis"> Documentation </span> <span class="md-nav__icon md-icon"></span> </label> <nav class="md-nav" data-md-level="1" aria-labelledby="__nav_7_label" aria-expanded="false"> <label class="md-nav__title" for="__nav_7"> <span class="md-nav__icon md-icon"></span> Documentation </label> <ul class="md-nav__list" data-md-scrollfix> <li class="md-nav__item"> <a href="../../image_classifier/" class="md-nav__link"> <span class="md-ellipsis"> ImageClassifier </span> </a> </li> <li class="md-nav__item"> <a href="../../image_regressor/" class="md-nav__link"> <span class="md-ellipsis"> ImageRegressor </span> </a> </li> <li class="md-nav__item"> <a href="../../text_classifier/" class="md-nav__link"> <span class="md-ellipsis"> TextClassifier </span> </a> </li> <li class="md-nav__item"> <a href="../../text_regressor/" class="md-nav__link"> <span class="md-ellipsis"> TextRegressor </span> </a> </li> <li class="md-nav__item"> <a href="../../auto_model/" class="md-nav__link"> <span class="md-ellipsis"> AutoModel </span> </a> </li> <li class="md-nav__item"> <a href="../../base/" class="md-nav__link"> <span class="md-ellipsis"> Base Class </span> </a> </li> <li class="md-nav__item"> <a href="../../node/" class="md-nav__link"> <span 
class="md-ellipsis"> Node </span> </a> </li> <li class="md-nav__item"> <a href="../../block/" class="md-nav__link"> <span class="md-ellipsis"> Block </span> </a> </li> <li class="md-nav__item"> <a href="../../utils/" class="md-nav__link"> <span class="md-ellipsis"> Utils </span> </a> </li> </ul> </nav> </li> <li class="md-nav__item"> <a href="../../benchmarks/" class="md-nav__link"> <span class="md-ellipsis"> Benchmarks </span> </a> </li> <li class="md-nav__item"> <a href="../../about/" class="md-nav__link"> <span class="md-ellipsis"> About </span> </a> </li> </ul> </nav> </div> </div> </div> <div class="md-sidebar md-sidebar--secondary" data-md-component="sidebar" data-md-type="toc" > <div class="md-sidebar__scrollwrap"> <div class="md-sidebar__inner"> <nav class="md-nav md-nav--secondary" aria-label="Table of contents"> <label class="md-nav__title" for="__toc"> <span class="md-nav__icon md-icon"></span> Table of contents </label> <ul class="md-nav__list" data-md-component="toc" data-md-scrollfix> <li class="md-nav__item"> <a href="#what-is-multi-modal" class="md-nav__link"> <span class="md-ellipsis"> What is multi-modal? </span> </a> </li> <li class="md-nav__item"> <a href="#what-is-multi-task" class="md-nav__link"> <span class="md-ellipsis"> What is multi-task? 
</span> </a> </li> <li class="md-nav__item"> <a href="#data-preparation" class="md-nav__link"> <span class="md-ellipsis"> Data Preparation </span> </a> </li> <li class="md-nav__item"> <a href="#build-and-train-the-model" class="md-nav__link"> <span class="md-ellipsis"> Build and Train the Model </span> </a> </li> <li class="md-nav__item"> <a href="#validation-data" class="md-nav__link"> <span class="md-ellipsis"> Validation Data </span> </a> </li> <li class="md-nav__item"> <a href="#customized-search-space" class="md-nav__link"> <span class="md-ellipsis"> Customized Search Space </span> </a> </li> <li class="md-nav__item"> <a href="#data-format" class="md-nav__link"> <span class="md-ellipsis"> Data Format </span> </a> </li> <li class="md-nav__item"> <a href="#reference" class="md-nav__link"> <span class="md-ellipsis"> Reference </span> </a> </li> </ul> </nav> </div> </div> </div> <div class="md-content" data-md-component="content"> <article class="md-content__inner md-typeset"> <h1>Multi-Modal and Multi-Task</h1> <p><span class="twemoji"><svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M3.9 12c0-1.71 1.39-3.1 3.1-3.1h4V7H7a5 5 0 0 0-5 5 5 5 0 0 0 5 5h4v-1.9H7c-1.71 0-3.1-1.39-3.1-3.1M8 13h8v-2H8v2m9-6h-4v1.9h4c1.71 0 3.1 1.39 3.1 3.1 0 1.71-1.39 3.1-3.1 3.1h-4V17h4a5 5 0 0 0 5-5 5 5 0 0 0-5-5Z"/></svg></span> <a href="https://colab.research.google.com/github/keras-team/autokeras/blob/master/docs/ipynb/multi.ipynb"><strong>View in Colab</strong></a> <span class="twemoji"><svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 16 16"><path d="M8 0c4.42 0 8 3.58 8 8a8.013 8.013 0 0 1-5.45 7.59c-.4.08-.55-.17-.55-.38 0-.27.01-1.13.01-2.2 0-.75-.25-1.23-.54-1.48 1.78-.2 3.65-.88 3.65-3.95 0-.88-.31-1.59-.82-2.15.08-.2.36-1.02-.08-2.12 0 0-.67-.22-2.2.82-.64-.18-1.32-.27-2-.27-.68 0-1.36.09-2 .27-1.53-1.03-2.2-.82-2.2-.82-.44 1.1-.16 1.92-.08 2.12-.51.56-.82 1.28-.82 2.15 0 3.06 1.86 3.75 3.64 3.95-.23.2-.44.55-.51 
1.07-.46.21-1.61.55-2.33-.66-.15-.24-.6-.83-1.23-.82-.67.01-.27.38.01.53.34.19.73.9.82 1.13.16.45.68 1.31 2.69.94 0 .67.01 1.3.01 1.49 0 .21-.15.45-.55.38A7.995 7.995 0 0 1 0 8c0-4.42 3.58-8 8-8Z"/></svg></span> <a href="https://github.com/keras-team/autokeras/blob/master/docs/py/multi.py"><strong>GitHub source</strong></a></p> <div class="highlight"><pre><span></span><code><span class="err">!</span><span class="n">pip</span> <span class="n">install</span> <span class="n">autokeras</span> </code></pre></div> <div class="highlight"><pre><span></span><code><span class="kn">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span> <span class="kn">import</span> <span class="nn">autokeras</span> <span class="k">as</span> <span class="nn">ak</span> </code></pre></div> <p>In this tutorial we are making use of the <a href="/auto_model/#automodel-class">AutoModel</a> API to show how to handle multi-modal data and multi-task learning.</p> <h2 id="what-is-multi-modal">What is multi-modal?</h2> <p>Multi-modal data means each data instance has multiple forms of information. For example, a photo can be saved as an image. Besides the image, it may also have when and where it was taken as its attributes, which can be represented as numerical data.</p> <h2 id="what-is-multi-task">What is multi-task?</h2> <p>By multi-task, we mean that we want to predict multiple targets with the same input features. For example, we not only want to classify an image according to its content, but we also want to regress its quality as a float number between 0 and 1.</p> <p>The following diagram shows an example of a multi-modal and multi-task neural network model.</p> <div class="mermaid"> graph TD id1(ImageInput) --> id3(Some Neural Network Model) id2(Input) --> id3 id3 --> id4(ClassificationHead) id3 --> id5(RegressionHead) </div> <p>It has two inputs: the images and the numerical input data. Each image is associated with a set of attributes in the numerical input data. 
From these data, we are trying to predict the classification label and the regression value at the same time.</p> <h2 id="data-preparation">Data Preparation</h2> <p>To illustrate our idea, we generate some random image and numerical data as the multi-modal data.</p> <div class="highlight"><pre><span></span><code><span class="n">num_instances</span> <span class="o">=</span> <span class="mi">10</span> <span class="c1"># Generate image data.</span> <span class="n">image_data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">num_instances</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="c1"># Generate numerical data.</span> <span class="n">numerical_data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">num_instances</span><span class="p">,</span> <span class="mi">20</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> </code></pre></div> <p>We also generate some multi-task targets for classification and regression.</p> <div class="highlight"><pre><span></span><code><span class="c1"># Generate regression targets.</span> <span class="n">regression_target</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span 
class="n">rand</span><span class="p">(</span><span class="n">num_instances</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="c1"># Generate classification labels of five classes.</span> <span class="n">classification_target</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">num_instances</span><span class="p">)</span> </code></pre></div> <h2 id="build-and-train-the-model">Build and Train the Model</h2> <p>Then we initialize the multi-modal and multi-task model with <a href="/auto_model/#automodel-class">AutoModel</a>. 
Since this is just a demo, we use small values for <code>max_trials</code> and <code>epochs</code>.</p> <div class="highlight"><pre><span></span><code><span class="c1"># Initialize the multi-modal, multi-task model with multiple inputs and outputs.</span> <span class="n">model</span> <span class="o">=</span> <span class="n">ak</span><span class="o">.</span><span class="n">AutoModel</span><span class="p">(</span> <span class="n">inputs</span><span class="o">=</span><span class="p">[</span><span class="n">ak</span><span class="o">.</span><span class="n">ImageInput</span><span class="p">(),</span> <span class="n">ak</span><span class="o">.</span><span class="n">Input</span><span class="p">()],</span> <span class="n">outputs</span><span class="o">=</span><span class="p">[</span> <span class="n">ak</span><span class="o">.</span><span class="n">RegressionHead</span><span class="p">(</span><span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"mae"</span><span class="p">]),</span> <span class="n">ak</span><span class="o">.</span><span class="n">ClassificationHead</span><span class="p">(</span> <span class="n">loss</span><span class="o">=</span><span class="s2">"categorical_crossentropy"</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">"accuracy"</span><span class="p">]</span> <span class="p">),</span> <span class="p">],</span> <span class="n">overwrite</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">max_trials</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Fit the model with prepared data.</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="p">[</span><span class="n">image_data</span><span class="p">,</span> <span class="n">numerical_data</span><span class="p">],</span> <span 
class="p">[</span><span class="n">regression_target</span><span class="p">,</span> <span class="n">classification_target</span><span class="p">],</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <h2 id="validation-data">Validation Data</h2> <p>By default, AutoKeras uses the last 20% of the training data as validation data. As shown in the example below, you can use <code>validation_split</code> to specify the percentage.</p> <div class="highlight"><pre><span></span><code><span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="p">[</span><span class="n">image_data</span><span class="p">,</span> <span class="n">numerical_data</span><span class="p">],</span> <span class="p">[</span><span class="n">regression_target</span><span class="p">,</span> <span class="n">classification_target</span><span class="p">],</span> <span class="c1"># Split the training data and use the last 15% as validation data.</span> <span class="n">validation_split</span><span class="o">=</span><span class="mf">0.15</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <p>You can also pass your own validation set with <code>validation_data</code> instead of splitting it from the training data.</p> <div class="highlight"><pre><span></span><code><span class="n">split</span> <span class="o">=</span> <span class="mi">5</span> <span class="n">image_val</span> <span class="o">=</span> <span class="n">image_data</span><span class="p">[</span><span class="n">split</span><span class="p">:]</span> <span 
class="n">numerical_val</span> <span class="o">=</span> <span class="n">numerical_data</span><span class="p">[</span><span class="n">split</span><span class="p">:]</span> <span class="n">regression_val</span> <span class="o">=</span> <span class="n">regression_target</span><span class="p">[</span><span class="n">split</span><span class="p">:]</span> <span class="n">classification_val</span> <span class="o">=</span> <span class="n">classification_target</span><span class="p">[</span><span class="n">split</span><span class="p">:]</span> <span class="n">image_data</span> <span class="o">=</span> <span class="n">image_data</span><span class="p">[:</span><span class="n">split</span><span class="p">]</span> <span class="n">numerical_data</span> <span class="o">=</span> <span class="n">numerical_data</span><span class="p">[:</span><span class="n">split</span><span class="p">]</span> <span class="n">regression_target</span> <span class="o">=</span> <span class="n">regression_target</span><span class="p">[:</span><span class="n">split</span><span class="p">]</span> <span class="n">classification_target</span> <span class="o">=</span> <span class="n">classification_target</span><span class="p">[:</span><span class="n">split</span><span class="p">]</span> <span class="n">model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="p">[</span><span class="n">image_data</span><span class="p">,</span> <span class="n">numerical_data</span><span class="p">],</span> <span class="p">[</span><span class="n">regression_target</span><span class="p">,</span> <span class="n">classification_target</span><span class="p">],</span> <span class="c1"># Use your own validation set.</span> <span class="n">validation_data</span><span class="o">=</span><span class="p">(</span> <span class="p">[</span><span class="n">image_val</span><span class="p">,</span> <span class="n">numerical_val</span><span class="p">],</span> <span class="p">[</span><span 
class="n">regression_val</span><span class="p">,</span> <span class="n">classification_val</span><span class="p">],</span> <span class="p">),</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <h2 id="customized-search-space">Customized Search Space</h2> <p>You can customize your search space. The following figure shows the search space we want to define.</p> <div class="mermaid"> graph LR id1(ImageInput) --> id2(Normalization) id2 --> id3(Image Augmentation) id3 --> id4(Convolutional) id3 --> id5(ResNet V2) id4 --> id6(Merge) id5 --> id6 id7(Input) --> id9(DenseBlock) id6 --> id10(Merge) id9 --> id10 id10 --> id11(Classification Head) id10 --> id12(Regression Head) </div> <div class="highlight"><pre><span></span><code><span class="n">input_node1</span> <span class="o">=</span> <span class="n">ak</span><span class="o">.</span><span class="n">ImageInput</span><span class="p">()</span> <span class="n">output_node</span> <span class="o">=</span> <span class="n">ak</span><span class="o">.</span><span class="n">Normalization</span><span class="p">()(</span><span class="n">input_node1</span><span class="p">)</span> <span class="n">output_node</span> <span class="o">=</span> <span class="n">ak</span><span class="o">.</span><span class="n">ImageAugmentation</span><span class="p">()(</span><span class="n">output_node</span><span class="p">)</span> <span class="n">output_node1</span> <span class="o">=</span> <span class="n">ak</span><span class="o">.</span><span class="n">ConvBlock</span><span class="p">()(</span><span class="n">output_node</span><span class="p">)</span> <span class="n">output_node2</span> <span class="o">=</span> <span class="n">ak</span><span class="o">.</span><span class="n">ResNetBlock</span><span class="p">(</span><span 
class="n">version</span><span class="o">=</span><span class="s2">"v2"</span><span class="p">)(</span><span class="n">output_node</span><span class="p">)</span> <span class="n">output_node1</span> <span class="o">=</span> <span class="n">ak</span><span class="o">.</span><span class="n">Merge</span><span class="p">()([</span><span class="n">output_node1</span><span class="p">,</span> <span class="n">output_node2</span><span class="p">])</span> <span class="n">input_node2</span> <span class="o">=</span> <span class="n">ak</span><span class="o">.</span><span class="n">Input</span><span class="p">()</span> <span class="n">output_node2</span> <span class="o">=</span> <span class="n">ak</span><span class="o">.</span><span class="n">DenseBlock</span><span class="p">()(</span><span class="n">input_node2</span><span class="p">)</span> <span class="n">output_node</span> <span class="o">=</span> <span class="n">ak</span><span class="o">.</span><span class="n">Merge</span><span class="p">()([</span><span class="n">output_node1</span><span class="p">,</span> <span class="n">output_node2</span><span class="p">])</span> <span class="n">output_node1</span> <span class="o">=</span> <span class="n">ak</span><span class="o">.</span><span class="n">ClassificationHead</span><span class="p">()(</span><span class="n">output_node</span><span class="p">)</span> <span class="n">output_node2</span> <span class="o">=</span> <span class="n">ak</span><span class="o">.</span><span class="n">RegressionHead</span><span class="p">()(</span><span class="n">output_node</span><span class="p">)</span> <span class="n">auto_model</span> <span class="o">=</span> <span class="n">ak</span><span class="o">.</span><span class="n">AutoModel</span><span class="p">(</span> <span class="n">inputs</span><span class="o">=</span><span class="p">[</span><span class="n">input_node1</span><span class="p">,</span> <span class="n">input_node2</span><span class="p">],</span> <span class="n">outputs</span><span 
class="o">=</span><span class="p">[</span><span class="n">output_node1</span><span class="p">,</span> <span class="n">output_node2</span><span class="p">],</span> <span class="n">overwrite</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">max_trials</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="p">)</span> <span class="n">image_data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">num_instances</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">32</span><span class="p">,</span> <span class="mi">3</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="n">numerical_data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">num_instances</span><span class="p">,</span> <span class="mi">20</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span> <span class="n">regression_target</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">num_instances</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span 
class="n">float32</span><span class="p">)</span> <span class="n">classification_target</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">random</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">5</span><span class="p">,</span> <span class="n">size</span><span class="o">=</span><span class="n">num_instances</span><span class="p">)</span> <span class="n">auto_model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="p">[</span><span class="n">image_data</span><span class="p">,</span> <span class="n">numerical_data</span><span class="p">],</span> <span class="p">[</span><span class="n">classification_target</span><span class="p">,</span> <span class="n">regression_target</span><span class="p">],</span> <span class="n">batch_size</span><span class="o">=</span><span class="mi">3</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="p">)</span> </code></pre></div> <h2 id="data-format">Data Format</h2> <p>You can refer to the documentation of <a href="/node/#imageinput-class">ImageInput</a>, <a href="/node/#input-class">Input</a>, <a href="/node/#textinput-class">TextInput</a>, <a href="/block/#regressionhead-class">RegressionHead</a>, and <a href="/block/#classificationhead-class">ClassificationHead</a> for the format of different types of data. 
You can also refer to the Data Format section of the tutorials of <a href="/tutorial/image_classification/#data-format">Image Classification</a> and <a href="/tutorial/text_classification/#data-format">Text Classification</a>.</p> <h2 id="reference">Reference</h2> <p><a href="/auto_model/#automodel-class">AutoModel</a>, <a href="/node/#imageinput-class">ImageInput</a>, <a href="/node/#input-class">Input</a>, <a href="/block/#denseblock-class">DenseBlock</a>, <a href="/block/#regressionhead-class">RegressionHead</a>, <a href="/block/#classificationhead-class">ClassificationHead</a>, <a href="/block/#categoricaltonumerical-class">CategoricalToNumerical</a>.</p> </article> </div> <script>var target=document.getElementById(location.hash.slice(1));target&&target.name&&(target.checked=target.name.startsWith("__tabbed_"))</script> </div> </main> <footer class="md-footer"> <div class="md-footer-meta md-typeset"> <div class="md-footer-meta__inner md-grid"> <div class="md-copyright"> Made with <a href="https://squidfunk.github.io/mkdocs-material/" target="_blank" rel="noopener"> Material for MkDocs </a> </div> </div> </div> </footer> </div> <div class="md-dialog" data-md-component="dialog"> <div class="md-dialog__inner md-typeset"></div> </div> <script id="__config" type="application/json">{"base": "../..", "features": [], "search": "../../assets/javascripts/workers/search.b8dbb3d2.min.js", "translations": {"clipboard.copied": "Copied to clipboard", "clipboard.copy": "Copy to clipboard", "search.result.more.one": "1 more on this page", "search.result.more.other": "# more on this page", "search.result.none": "No matching documents", "search.result.one": "1 matching document", "search.result.other": "# matching documents", "search.result.placeholder": "Type to start searching", "search.result.term.missing": "Missing", "select.version": "Select version"}}</script> <script src="../../assets/javascripts/bundle.bd41221c.min.js"></script> <script 
src="https://unpkg.com/mermaid@8.4.4/dist/mermaid.min.js"></script> </body> </html>