CINXE.COM

Large-scale multi-label text classification

<!DOCTYPE html> <html lang="en"> <head> <meta charset="utf-8"> <meta name="viewport" content="width=device-width, initial-scale=1"> <meta name="description" content="Keras documentation"> <meta name="author" content="Keras Team"> <link rel="shortcut icon" href="https://keras.io/img/favicon.ico"> <link rel="canonical" href="https://keras.io/examples/nlp/multi_label_classification/" /> <!-- Social --> <meta property="og:title" content="Keras documentation: Large-scale multi-label text classification"> <meta property="og:image" content="https://keras.io/img/logo-k-keras-wb.png"> <meta name="twitter:title" content="Keras documentation: Large-scale multi-label text classification"> <meta name="twitter:image" content="https://keras.io/img/k-keras-social.png"> <meta name="twitter:card" content="summary"> <title>Large-scale multi-label text classification</title> <!-- Bootstrap core CSS --> <link href="/css/bootstrap.min.css" rel="stylesheet"> <!-- Custom fonts for this template --> <link href="https://fonts.googleapis.com/css2?family=Open+Sans:wght@400;600;700;800&display=swap" rel="stylesheet"> <!-- Custom styles for this template --> <link href="/css/docs.css" rel="stylesheet"> <link href="/css/monokai.css" rel="stylesheet"> <!-- Google Tag Manager --> <script>(function(w,d,s,l,i){w[l]=w[l]||[];w[l].push({'gtm.start': new Date().getTime(),event:'gtm.js'});var f=d.getElementsByTagName(s)[0], j=d.createElement(s),dl=l!='dataLayer'?'&l='+l:'';j.async=true;j.src= 'https://www.googletagmanager.com/gtm.js?id='+i+dl;f.parentNode.insertBefore(j,f); })(window,document,'script','dataLayer','GTM-5DNGF4N'); </script> <script> (function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){ (i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o), m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m) })(window,document,'script','https://www.google-analytics.com/analytics.js','ga'); ga('create', 'UA-175165319-128', 'auto'); ga('send', 'pageview'); </script> <!-- End Google Tag Manager --> <script async defer src="https://buttons.github.io/buttons.js"></script> </head> <body> <!-- Google Tag Manager (noscript) --> <noscript><iframe src="https://www.googletagmanager.com/ns.html?id=GTM-5DNGF4N" height="0" width="0" style="display:none;visibility:hidden"></iframe></noscript> <!-- End Google Tag Manager (noscript) --> <div class='k-page'> <div class="k-nav" id="nav-menu"> <a href='/'><img src='/img/logo-small.png' class='logo-small' /></a> <div class="nav flex-column nav-pills" role="tablist" aria-orientation="vertical"> <a class="nav-link" href="/about/" role="tab" aria-selected="">About Keras</a> <a class="nav-link" href="/getting_started/" role="tab" aria-selected="">Getting started</a> <a class="nav-link" href="/guides/" role="tab" aria-selected="">Developer guides</a> <a class="nav-link active" href="/examples/" role="tab" aria-selected="">Code examples</a> <a class="nav-sublink" href="/examples/vision/">Computer Vision</a> <a class="nav-sublink active" href="/examples/nlp/">Natural Language Processing</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_from_scratch/">Text classification from scratch</a> <a class="nav-sublink2" href="/examples/nlp/active_learning_review_classification/">Review Classification using Active Learning</a> <a class="nav-sublink2" href="/examples/nlp/fnet_classification_with_keras_hub/">Text Classification using FNet</a> <a class="nav-sublink2 active" href="/examples/nlp/multi_label_classification/">Large-scale multi-label text classification</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_with_transformer/">Text classification with Transformer</a> <a class="nav-sublink2" href="/examples/nlp/text_classification_with_switch_transformer/">Text classification with Switch Transformer</a> <a class="nav-sublink2" href="/examples/nlp/tweet-classification-using-tfdf/">Text classification using Decision Forests and pretrained embeddings</a> <a class="nav-sublink2" href="/examples/nlp/pretrained_word_embeddings/">Using pre-trained word embeddings</a> <a class="nav-sublink2" href="/examples/nlp/bidirectional_lstm_imdb/">Bidirectional LSTM on IMDB</a> <a class="nav-sublink2" href="/examples/nlp/data_parallel_training_with_keras_hub/">Data Parallel Training with KerasHub and tf.distribute</a> <a class="nav-sublink2" href="/examples/nlp/neural_machine_translation_with_keras_hub/">English-to-Spanish translation with KerasHub</a> <a class="nav-sublink2" href="/examples/nlp/neural_machine_translation_with_transformer/">English-to-Spanish translation with a sequence-to-sequence Transformer</a> <a class="nav-sublink2" href="/examples/nlp/lstm_seq2seq/">Character-level recurrent sequence-to-sequence model</a> <a class="nav-sublink2" href="/examples/nlp/multimodal_entailment/">Multimodal entailment</a> <a class="nav-sublink2" href="/examples/nlp/ner_transformers/">Named Entity Recognition using Transformers</a> <a class="nav-sublink2" href="/examples/nlp/text_extraction_with_bert/">Text Extraction with BERT</a> <a class="nav-sublink2" href="/examples/nlp/addition_rnn/">Sequence to sequence learning for performing number addition</a> <a class="nav-sublink2" href="/examples/nlp/semantic_similarity_with_keras_hub/">Semantic Similarity with KerasHub</a> <a class="nav-sublink2" href="/examples/nlp/semantic_similarity_with_bert/">Semantic Similarity with BERT</a> <a class="nav-sublink2" href="/examples/nlp/sentence_embeddings_with_sbert/">Sentence embeddings using Siamese RoBERTa-networks</a> <a class="nav-sublink2" href="/examples/nlp/masked_language_modeling/">End-to-end Masked Language Modeling with BERT</a> <a class="nav-sublink2" href="/examples/nlp/abstractive_summarization_with_bart/">Abstractive Text Summarization with BART</a> <a class="nav-sublink2" href="/examples/nlp/pretraining_BERT/">Pretraining BERT with Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/nlp/parameter_efficient_finetuning_of_gpt2_with_lora/">Parameter-efficient fine-tuning of GPT-2 with LoRA</a> <a class="nav-sublink2" href="/examples/nlp/multiple_choice_task_with_transfer_learning/">MultipleChoice Task with Transfer Learning</a> <a class="nav-sublink2" href="/examples/nlp/question_answering/">Question Answering with Hugging Face Transformers</a> <a class="nav-sublink2" href="/examples/nlp/t5_hf_summarization/">Abstractive Summarization with Hugging Face Transformers</a> <a class="nav-sublink" href="/examples/structured_data/">Structured Data</a> <a class="nav-sublink" href="/examples/timeseries/">Timeseries</a> <a class="nav-sublink" href="/examples/generative/">Generative Deep Learning</a> <a class="nav-sublink" href="/examples/audio/">Audio Data</a> <a class="nav-sublink" href="/examples/rl/">Reinforcement Learning</a> <a class="nav-sublink" href="/examples/graph/">Graph Data</a> <a class="nav-sublink" href="/examples/keras_recipes/">Quick Keras Recipes</a> <a class="nav-link" href="/api/" role="tab" aria-selected="">Keras 3 API documentation</a> <a class="nav-link" href="/2.18/api/" role="tab" aria-selected="">Keras 2 API documentation</a> <a class="nav-link" href="/keras_tuner/" role="tab" aria-selected="">KerasTuner: Hyperparam Tuning</a> <a class="nav-link" href="/keras_hub/" role="tab" aria-selected="">KerasHub: Pretrained Models</a> </div> </div> <div class='k-main'> <div class='k-main-top'> <script> function displayDropdownMenu() { e = document.getElementById("nav-menu"); if (e.style.display == "block") { e.style.display = "none"; } else { e.style.display = "block"; document.getElementById("dropdown-nav").style.display = "block"; } } function resetMobileUI() { if (window.innerWidth <= 840) { document.getElementById("nav-menu").style.display = "none"; document.getElementById("dropdown-nav").style.display = "block"; } else { document.getElementById("nav-menu").style.display = "block"; document.getElementById("dropdown-nav").style.display = "none"; } var navmenu = document.getElementById("nav-menu"); var menuheight = navmenu.clientHeight; var kmain = document.getElementById("k-main-id"); kmain.style.minHeight = (menuheight + 100) + 'px'; } window.onresize = resetMobileUI; window.addEventListener("load", (event) => { resetMobileUI() }); </script> <div id='dropdown-nav' onclick="displayDropdownMenu();"> <svg viewBox="-20 -20 120 120" width="60" height="60"> <rect width="100" height="20"></rect> <rect y="30" width="100" height="20"></rect> <rect y="60" width="100" height="20"></rect> </svg> </div> <form class="bd-search d-flex align-items-center k-search-form" id="search-form"> <input type="search" class="k-search-input" id="search-input" placeholder="Search Keras documentation..." aria-label="Search Keras documentation..." autocomplete="off"> <button class="k-search-btn"> <svg width="13" height="13" viewBox="0 0 13 13"><title>search</title><path d="m4.8495 7.8226c0.82666 0 1.5262-0.29146 2.0985-0.87438 0.57232-0.58292 0.86378-1.2877 0.87438-2.1144 0.010599-0.82666-0.28086-1.5262-0.87438-2.0985-0.59352-0.57232-1.293-0.86378-2.0985-0.87438-0.8055-0.010599-1.5103 0.28086-2.1144 0.87438-0.60414 0.59352-0.8956 1.293-0.87438 2.0985 0.021197 0.8055 0.31266 1.5103 0.87438 2.1144 0.56172 0.60414 1.2665 0.8956 2.1144 0.87438zm4.4695 0.2115 3.681 3.6819-1.259 1.284-3.6817-3.7 0.0019784-0.69479-0.090043-0.098846c-0.87973 0.76087-1.92 1.1413-3.1207 1.1413-1.3553 0-2.5025-0.46363-3.4417-1.3909s-1.4088-2.0686-1.4088-3.4239c0-1.3553 0.4696-2.4966 1.4088-3.4239 0.9392-0.92727 2.0864-1.3969 3.4417-1.4088 1.3553-0.011889 2.4906 0.45771 3.406 1.4088 0.9154 0.95107 1.379 2.0924 1.3909 3.4239 0 1.2126-0.38043 2.2588-1.1413 3.1385l0.098834 0.090049z"></path></svg> </button> </form> <script> var form = document.getElementById('search-form'); form.onsubmit = function(e) { e.preventDefault(); var query = document.getElementById('search-input').value; window.location.href = '/search.html?query=' + query; return False } </script> </div> <div class='k-main-inner' id='k-main-id'> <div class='k-location-slug'> <span class="k-location-slug-pointer">►</span> <a href='/examples/'>Code examples</a> / <a href='/examples/nlp/'>Natural Language Processing</a> / Large-scale multi-label text classification </div> <div class='k-content'> <h1 id="largescale-multilabel-text-classification">Large-scale multi-label text classification</h1> <p><strong>Author:</strong> <a href="https://twitter.com/RisingSayak">Sayak Paul</a>, <a href="https://github.com/soumik12345">Soumik Rakshit</a><br> <strong>Date created:</strong> 2020/09/25<br> <strong>Last modified:</strong> 2020/12/23<br> <strong>Description:</strong> Implementing a large-scale multi-label text classification model.</p> <div class='example_version_banner keras_2'>ⓘ This example uses Keras 2</div> <p><img class="k-inline-icon" src="https://colab.research.google.com/img/colab_favicon.ico"/> <a href="https://colab.research.google.com/github/keras-team\keras-io\blob\master\examples\nlp/ipynb/multi_label_classification.ipynb"><strong>View in Colab</strong></a> <span class="k-dot">•</span><img class="k-inline-icon" src="https://github.com/favicon.ico"/> <a href="https://github.com/keras-team\keras-io\blob\master\examples\nlp/multi_label_classification.py"><strong>GitHub source</strong></a></p> <hr /> <h2 id="introduction">Introduction</h2> <p>In this example, we will build a multi-label text classifier to predict the subject areas of arXiv papers from their abstract bodies. This type of classifier can be useful for conference submission portals like <a href="https://openreview.net/">OpenReview</a>. Given a paper abstract, the portal could provide suggestions for which areas the paper would best belong to.</p> <p>The dataset was collected using the <a href="https://github.com/lukasschwab/arxiv.py"><code>arXiv</code> Python library</a> that provides a wrapper around the <a href="http://arxiv.org/help/api/index">original arXiv API</a>. To learn more about the data collection process, please refer to <a href="https://github.com/soumik12345/multi-label-text-classification/blob/master/arxiv_scrape.ipynb">this notebook</a>. Additionally, you can also find the dataset on <a href="https://www.kaggle.com/spsayakpaul/arxiv-paper-abstracts">Kaggle</a>.</p> <hr /> <h2 id="imports">Imports</h2> <div class="codehilite"><pre><span></span><code><span class="kn">from</span><span class="w"> </span><span class="nn">tensorflow.keras</span><span class="w"> </span><span class="kn">import</span> <span class="n">layers</span> <span class="kn">from</span><span class="w"> </span><span class="nn">tensorflow</span><span class="w"> </span><span class="kn">import</span> <span class="n">keras</span> <span class="kn">import</span><span class="w"> </span><span class="nn">tensorflow</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">tf</span> <span class="kn">from</span><span class="w"> </span><span class="nn">sklearn.model_selection</span><span class="w"> </span><span class="kn">import</span> <span class="n">train_test_split</span> <span class="kn">from</span><span class="w"> </span><span class="nn">ast</span><span class="w"> </span><span class="kn">import</span> <span class="n">literal_eval</span> <span class="kn">import</span><span class="w"> </span><span class="nn">matplotlib.pyplot</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">plt</span> <span class="kn">import</span><span class="w"> </span><span class="nn">pandas</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">pd</span> <span class="kn">import</span><span class="w"> </span><span class="nn">numpy</span><span class="w"> </span><span class="k">as</span><span class="w"> </span><span class="nn">np</span> </code></pre></div> <hr /> <h2 id="perform-exploratory-data-analysis">Perform exploratory data analysis</h2> <p>In this section, we first load the dataset into a <code>pandas</code> dataframe and then perform some basic exploratory data analysis (EDA).</p> <div class="codehilite"><pre><span></span><code><span class="n">arxiv_data</span> <span class="o">=</span> <span class="n">pd</span><span class="o">.</span><span class="n">read_csv</span><span class="p">(</span> <span class="s2">&quot;https://github.com/soumik12345/multi-label-text-classification/releases/download/v0.2/arxiv_data.csv&quot;</span> <span class="p">)</span> <span class="n">arxiv_data</span><span class="o">.</span><span class="n">head</span><span class="p">()</span> </code></pre></div> <div> <style scoped> .dataframe tbody tr th:only-of-type { vertical-align: middle; } <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>.dataframe tbody tr th { vertical-align: top; } .dataframe thead th { text-align: right; } </code></pre></div> </div> </style> <table border="1" class="dataframe"> <thead> <tr style="text-align: right;"> <th></th> <th>titles</th> <th>summaries</th> <th>terms</th> </tr> </thead> <tbody> <tr> <th>0</th> <td>Survey on Semantic Stereo Matching / Semantic ...</td> <td>Stereo matching is one of the widely used tech...</td> <td>['cs.CV', 'cs.LG']</td> </tr> <tr> <th>1</th> <td>FUTURE-AI: Guiding Principles and Consensus Re...</td> <td>The recent advancements in artificial intellig...</td> <td>['cs.CV', 'cs.AI', 'cs.LG']</td> </tr> <tr> <th>2</th> <td>Enforcing Mutual Consistency of Hard Regions f...</td> <td>In this paper, we proposed a novel mutual cons...</td> <td>['cs.CV', 'cs.AI']</td> </tr> <tr> <th>3</th> <td>Parameter Decoupling Strategy for Semi-supervi...</td> <td>Consistency training has proven to be an advan...</td> <td>['cs.CV']</td> </tr> <tr> <th>4</th> <td>Background-Foreground Segmentation for Interio...</td> <td>To ensure safety in automated driving, the cor...</td> <td>['cs.CV', 'cs.LG']</td> </tr> </tbody> </table> </div> <p>Our text features are present in the <code>summaries</code> column and their corresponding labels are in <code>terms</code>. As you can notice, there are multiple categories associated with a particular entry.</p> <div class="codehilite"><pre><span></span><code><span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;There are </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">arxiv_data</span><span class="p">)</span><span class="si">}</span><span class="s2"> rows in the dataset.&quot;</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>There are 51774 rows in the dataset. </code></pre></div> </div> <p>Real-world data is noisy. One of the most commonly observed source of noise is data duplication. Here we notice that our initial dataset has got about 13k duplicate entries.</p> <div class="codehilite"><pre><span></span><code><span class="n">total_duplicate_titles</span> <span class="o">=</span> <span class="nb">sum</span><span class="p">(</span><span class="n">arxiv_data</span><span class="p">[</span><span class="s2">&quot;titles&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">duplicated</span><span class="p">())</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;There are </span><span class="si">{</span><span class="n">total_duplicate_titles</span><span class="si">}</span><span class="s2"> duplicate titles.&quot;</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>There are 12802 duplicate titles. </code></pre></div> </div> <p>Before proceeding further, we drop these entries.</p> <div class="codehilite"><pre><span></span><code><span class="n">arxiv_data</span> <span class="o">=</span> <span class="n">arxiv_data</span><span class="p">[</span><span class="o">~</span><span class="n">arxiv_data</span><span class="p">[</span><span class="s2">&quot;titles&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">duplicated</span><span class="p">()]</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;There are </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">arxiv_data</span><span class="p">)</span><span class="si">}</span><span class="s2"> rows in the deduplicated dataset.&quot;</span><span class="p">)</span> <span class="c1"># There are some terms with occurrence as low as 1.</span> <span class="nb">print</span><span class="p">(</span><span class="nb">sum</span><span class="p">(</span><span class="n">arxiv_data</span><span class="p">[</span><span class="s2">&quot;terms&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">value_counts</span><span class="p">()</span> <span class="o">==</span> <span class="mi">1</span><span class="p">))</span> <span class="c1"># How many unique terms?</span> <span class="nb">print</span><span class="p">(</span><span class="n">arxiv_data</span><span class="p">[</span><span class="s2">&quot;terms&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">nunique</span><span class="p">())</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>There are 38972 rows in the deduplicated dataset. 2321 3157 </code></pre></div> </div> <p>As observed above, out of 3,157 unique combinations of <code>terms</code>, 2,321 entries have the lowest occurrence. To prepare our train, validation, and test sets with <a href="https://en.wikipedia.org/wiki/Stratified_sampling">stratification</a>, we need to drop these terms.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Filtering the rare terms.</span> <span class="n">arxiv_data_filtered</span> <span class="o">=</span> <span class="n">arxiv_data</span><span class="o">.</span><span class="n">groupby</span><span class="p">(</span><span class="s2">&quot;terms&quot;</span><span class="p">)</span><span class="o">.</span><span class="n">filter</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">)</span> <span class="n">arxiv_data_filtered</span><span class="o">.</span><span class="n">shape</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>(36651, 3) </code></pre></div> </div> <hr /> <h2 id="convert-the-string-labels-to-lists-of-strings">Convert the string labels to lists of strings</h2> <p>The initial labels are represented as raw strings. Here we make them <code>List[str]</code> for a more compact representation.</p> <div class="codehilite"><pre><span></span><code><span class="n">arxiv_data_filtered</span><span class="p">[</span><span class="s2">&quot;terms&quot;</span><span class="p">]</span> <span class="o">=</span> <span class="n">arxiv_data_filtered</span><span class="p">[</span><span class="s2">&quot;terms&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span> <span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">literal_eval</span><span class="p">(</span><span class="n">x</span><span class="p">)</span> <span class="p">)</span> <span class="n">arxiv_data_filtered</span><span class="p">[</span><span class="s2">&quot;terms&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">values</span><span class="p">[:</span><span class="mi">5</span><span class="p">]</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>array([list([&#39;cs.CV&#39;, &#39;cs.LG&#39;]), list([&#39;cs.CV&#39;, &#39;cs.AI&#39;, &#39;cs.LG&#39;]), list([&#39;cs.CV&#39;, &#39;cs.AI&#39;]), list([&#39;cs.CV&#39;]), list([&#39;cs.CV&#39;, &#39;cs.LG&#39;])], dtype=object) </code></pre></div> </div> <hr /> <h2 id="use-stratified-splits-because-of-class-imbalance">Use stratified splits because of class imbalance</h2> <p>The dataset has a <a href="https://developers.google.com/machine-learning/glossary/#class-imbalanced-dataset">class imbalance problem</a>. So, to have a fair evaluation result, we need to ensure the datasets are sampled with stratification. To know more about different strategies to deal with the class imbalance problem, you can follow <a href="https://www.tensorflow.org/tutorials/structured_data/imbalanced_data">this tutorial</a>. For an end-to-end demonstration of classification with imbablanced data, refer to <a href="https://keras.io/examples/structured_data/imbalanced_classification/">Imbalanced classification: credit card fraud detection</a>.</p> <div class="codehilite"><pre><span></span><code><span class="n">test_split</span> <span class="o">=</span> <span class="mf">0.1</span> <span class="c1"># Initial train and test split.</span> <span class="n">train_df</span><span class="p">,</span> <span class="n">test_df</span> <span class="o">=</span> <span class="n">train_test_split</span><span class="p">(</span> <span class="n">arxiv_data_filtered</span><span class="p">,</span> <span class="n">test_size</span><span class="o">=</span><span class="n">test_split</span><span class="p">,</span> <span class="n">stratify</span><span class="o">=</span><span class="n">arxiv_data_filtered</span><span class="p">[</span><span class="s2">&quot;terms&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">values</span><span class="p">,</span> <span class="p">)</span> <span class="c1"># Splitting the test set further into validation</span> <span class="c1"># and new test sets.</span> <span class="n">val_df</span> <span class="o">=</span> <span class="n">test_df</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="n">frac</span><span class="o">=</span><span class="mf">0.5</span><span class="p">)</span> <span class="n">test_df</span><span class="o">.</span><span class="n">drop</span><span class="p">(</span><span class="n">val_df</span><span class="o">.</span><span class="n">index</span><span class="p">,</span> <span class="n">inplace</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Number of rows in training set: </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">train_df</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Number of rows in validation set: </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">val_df</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Number of rows in test set: </span><span class="si">{</span><span class="nb">len</span><span class="p">(</span><span class="n">test_df</span><span class="p">)</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Number of rows in training set: 32985 Number of rows in validation set: 1833 Number of rows in test set: 1833 </code></pre></div> </div> <hr /> <h2 id="multilabel-binarization">Multi-label binarization</h2> <p>Now we preprocess our labels using the <a href="https://keras.io/api/layers/preprocessing_layers/categorical/string_lookup"><code>StringLookup</code></a> layer.</p> <div class="codehilite"><pre><span></span><code><span class="n">terms</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">ragged</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="n">train_df</span><span class="p">[</span><span class="s2">&quot;terms&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">values</span><span class="p">)</span> <span class="n">lookup</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">StringLookup</span><span class="p">(</span><span class="n">output_mode</span><span class="o">=</span><span class="s2">&quot;multi_hot&quot;</span><span class="p">)</span> <span class="n">lookup</span><span class="o">.</span><span class="n">adapt</span><span class="p">(</span><span class="n">terms</span><span class="p">)</span> <span class="n">vocab</span> <span class="o">=</span> <span class="n">lookup</span><span class="o">.</span><span class="n">get_vocabulary</span><span class="p">()</span> <span class="k">def</span><span class="w"> </span><span class="nf">invert_multi_hot</span><span class="p">(</span><span class="n">encoded_labels</span><span class="p">):</span> <span class="w"> </span><span class="sd">&quot;&quot;&quot;Reverse a single multi-hot encoded label to a tuple of vocab terms.&quot;&quot;&quot;</span> <span class="n">hot_indices</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argwhere</span><span class="p">(</span><span class="n">encoded_labels</span> <span class="o">==</span> <span class="mf">1.0</span><span class="p">)[</span><span class="o">...</span><span class="p">,</span> <span class="mi">0</span><span class="p">]</span> <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">take</span><span class="p">(</span><span class="n">vocab</span><span class="p">,</span> <span class="n">hot_indices</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot;Vocabulary:</span><span class="se">\n</span><span class="s2">&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">vocab</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Vocabulary: </code></pre></div> </div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>[&#39;[UNK]&#39;, &#39;cs.CV&#39;, &#39;cs.LG&#39;, &#39;stat.ML&#39;, &#39;cs.AI&#39;, &#39;eess.IV&#39;, &#39;cs.RO&#39;, &#39;cs.CL&#39;, &#39;cs.NE&#39;, &#39;cs.CR&#39;, &#39;math.OC&#39;, &#39;eess.SP&#39;, &#39;cs.GR&#39;, &#39;cs.SI&#39;, &#39;cs.MM&#39;, &#39;cs.SY&#39;, &#39;cs.IR&#39;, &#39;cs.MA&#39;, &#39;eess.SY&#39;, &#39;cs.HC&#39;, &#39;math.IT&#39;, &#39;cs.IT&#39;, &#39;cs.DC&#39;, &#39;cs.CY&#39;, &#39;stat.AP&#39;, &#39;stat.TH&#39;, &#39;math.ST&#39;, &#39;stat.ME&#39;, &#39;eess.AS&#39;, &#39;cs.SD&#39;, &#39;q-bio.QM&#39;, &#39;q-bio.NC&#39;, &#39;cs.DS&#39;, &#39;cs.GT&#39;, &#39;cs.CG&#39;, &#39;cs.SE&#39;, &#39;cs.NI&#39;, &#39;I.2.6&#39;, &#39;stat.CO&#39;, &#39;math.NA&#39;, &#39;cs.NA&#39;, &#39;physics.chem-ph&#39;, &#39;cs.DB&#39;, &#39;q-bio.BM&#39;, &#39;cs.PL&#39;, &#39;cs.LO&#39;, &#39;cond-mat.dis-nn&#39;, &#39;68T45&#39;, &#39;math.PR&#39;, &#39;physics.comp-ph&#39;, &#39;I.2.10&#39;, &#39;cs.CE&#39;, &#39;cs.AR&#39;, &#39;q-fin.ST&#39;, &#39;cond-mat.stat-mech&#39;, &#39;68T05&#39;, &#39;quant-ph&#39;, &#39;math.DS&#39;, &#39;physics.data-an&#39;, &#39;cs.CC&#39;, &#39;I.4.6&#39;, &#39;physics.soc-ph&#39;, &#39;physics.ao-ph&#39;, &#39;cs.DM&#39;, &#39;econ.EM&#39;, &#39;q-bio.GN&#39;, &#39;physics.med-ph&#39;, &#39;astro-ph.IM&#39;, &#39;I.4.8&#39;, &#39;math.AT&#39;, &#39;cs.PF&#39;, &#39;cs.FL&#39;, &#39;I.4&#39;, &#39;q-fin.TR&#39;, &#39;I.5.4&#39;, &#39;I.2&#39;, &#39;68U10&#39;, &#39;hep-ex&#39;, &#39;cond-mat.mtrl-sci&#39;, &#39;68T10&#39;, &#39;physics.optics&#39;, &#39;physics.geo-ph&#39;, &#39;physics.flu-dyn&#39;, &#39;math.CO&#39;, &#39;math.AP&#39;, &#39;I.4; I.5&#39;, &#39;I.4.9&#39;, &#39;I.2.6; I.2.8&#39;, &#39;68T01&#39;, &#39;65D19&#39;, &#39;q-fin.CP&#39;, &#39;nlin.CD&#39;, &#39;cs.MS&#39;, &#39;I.2.6; I.5.1&#39;, &#39;I.2.10; I.4; I.5&#39;, &#39;I.2.0; I.2.6&#39;, &#39;68T07&#39;, &#39;q-fin.GN&#39;, &#39;cs.SC&#39;, &#39;cs.ET&#39;, &#39;K.3.2&#39;, &#39;I.2.8&#39;, &#39;68U01&#39;, &#39;68T30&#39;, &#39;q-fin.EC&#39;, &#39;q-bio.MN&#39;, &#39;econ.GN&#39;, &#39;I.4.9; I.5.4&#39;, &#39;I.4.5&#39;, &#39;I.2; I.5&#39;, &#39;I.2; I.4; I.5&#39;, &#39;I.2.6; I.2.7&#39;, &#39;I.2.10; I.4.8&#39;, &#39;68T99&#39;, &#39;68Q32&#39;, &#39;68&#39;, &#39;62H30&#39;, &#39;q-fin.RM&#39;, &#39;q-fin.PM&#39;, &#39;q-bio.TO&#39;, &#39;q-bio.OT&#39;, &#39;physics.bio-ph&#39;, &#39;nlin.AO&#39;, &#39;math.LO&#39;, &#39;math.FA&#39;, &#39;hep-ph&#39;, &#39;cond-mat.soft&#39;, &#39;I.4.6; I.4.8&#39;, &#39;I.4.4&#39;, &#39;I.4.3&#39;, &#39;I.4.0&#39;, &#39;I.2; J.2&#39;, &#39;I.2; I.2.6; I.2.7&#39;, &#39;I.2.7&#39;, &#39;I.2.6; I.5.4&#39;, &#39;I.2.6; I.2.9&#39;, &#39;I.2.6; I.2.7; H.3.1; H.3.3&#39;, &#39;I.2.6; I.2.10&#39;, &#39;I.2.6, I.5.4&#39;, &#39;I.2.1; J.3&#39;, &#39;I.2.10; I.5.1; I.4.8&#39;, &#39;I.2.10; I.4.8; I.5.4&#39;, &#39;I.2.10; I.2.6&#39;, &#39;I.2.1&#39;, &#39;H.3.1; I.2.6; I.2.7&#39;, &#39;H.3.1; H.3.3; I.2.6; I.2.7&#39;, &#39;G.3&#39;, &#39;F.2.2; I.2.7&#39;, &#39;E.5; E.4; E.2; H.1.1; F.1.1; F.1.3&#39;, &#39;68Txx&#39;, &#39;62H99&#39;, &#39;62H35&#39;, &#39;14J60 (Primary) 14F05, 14J26 (Secondary)&#39;] </code></pre></div> </div> <p>Here we are separating the individual unique classes available from the label pool and then using this information to represent a given label set with 0's and 1's. Below is an example.</p> <div class="codehilite"><pre><span></span><code><span class="n">sample_label</span> <span class="o">=</span> <span class="n">train_df</span><span class="p">[</span><span class="s2">&quot;terms&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">iloc</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Original label: </span><span class="si">{</span><span class="n">sample_label</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="n">label_binarized</span> <span class="o">=</span> <span class="n">lookup</span><span class="p">([</span><span class="n">sample_label</span><span class="p">])</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Label-binarized representation: </span><span class="si">{</span><span class="n">label_binarized</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Original label: [&#39;cs.LG&#39;, &#39;cs.CV&#39;, &#39;eess.IV&#39;] Label-binarized representation: [[0. 1. 1. 0. 0. 1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]] </code></pre></div> </div> <hr /> <h2 id="tfdatadataset">Data preprocessing and <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><code>tf.data.Dataset</code></a> objects</h2> <p>We first get percentile estimates of the sequence lengths. The purpose will be clear in a moment.</p> <div class="codehilite"><pre><span></span><code><span class="n">train_df</span><span class="p">[</span><span class="s2">&quot;summaries&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="nb">len</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">split</span><span class="p">(</span><span class="s2">&quot; &quot;</span><span class="p">)))</span><span class="o">.</span><span class="n">describe</span><span class="p">()</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>count 32985.000000 mean 156.497105 std 41.528225 min 5.000000 25% 128.000000 50% 154.000000 75% 183.000000 max 462.000000 Name: summaries, dtype: float64 </code></pre></div> </div> <p>Notice that 50% of the abstracts have a length of 154 (you may get a different number based on the split). So, any number close to that value is a good enough approximate for the maximum sequence length.</p> <p>Now, we implement utilities to prepare our datasets.</p> <div class="codehilite"><pre><span></span><code><span class="n">max_seqlen</span> <span class="o">=</span> <span class="mi">150</span> <span class="n">batch_size</span> <span class="o">=</span> <span class="mi">128</span> <span class="n">padding_token</span> <span class="o">=</span> <span class="s2">&quot;&lt;pad&gt;&quot;</span> <span class="n">auto</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">AUTOTUNE</span> <span class="k">def</span><span class="w"> </span><span class="nf">make_dataset</span><span class="p">(</span><span class="n">dataframe</span><span class="p">,</span> <span class="n">is_train</span><span class="o">=</span><span class="kc">True</span><span class="p">):</span> <span class="n">labels</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">ragged</span><span class="o">.</span><span class="n">constant</span><span class="p">(</span><span class="n">dataframe</span><span class="p">[</span><span class="s2">&quot;terms&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">values</span><span class="p">)</span> <span class="n">label_binarized</span> <span class="o">=</span> <span class="n">lookup</span><span class="p">(</span><span class="n">labels</span><span class="p">)</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">Dataset</span><span class="o">.</span><span class="n">from_tensor_slices</span><span class="p">(</span> <span class="p">(</span><span class="n">dataframe</span><span class="p">[</span><span class="s2">&quot;summaries&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">values</span><span class="p">,</span> <span class="n">label_binarized</span><span class="p">)</span> <span class="p">)</span> <span class="n">dataset</span> <span class="o">=</span> <span class="n">dataset</span><span class="o">.</span><span class="n">shuffle</span><span class="p">(</span><span class="n">batch_size</span> <span class="o">*</span> <span class="mi">10</span><span class="p">)</span> <span class="k">if</span> <span class="n">is_train</span> <span class="k">else</span> <span class="n">dataset</span> <span class="k">return</span> <span class="n">dataset</span><span class="o">.</span><span class="n">batch</span><span class="p">(</span><span class="n">batch_size</span><span class="p">)</span> </code></pre></div> <p>Now we can prepare the <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><code>tf.data.Dataset</code></a> objects.</p> <div class="codehilite"><pre><span></span><code><span class="n">train_dataset</span> <span class="o">=</span> <span class="n">make_dataset</span><span class="p">(</span><span class="n">train_df</span><span class="p">,</span> <span class="n">is_train</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span> <span class="n">validation_dataset</span> <span class="o">=</span> <span class="n">make_dataset</span><span class="p">(</span><span class="n">val_df</span><span class="p">,</span> <span class="n">is_train</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="n">test_dataset</span> <span class="o">=</span> <span class="n">make_dataset</span><span class="p">(</span><span class="n">test_df</span><span class="p">,</span> <span class="n">is_train</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> </code></pre></div> <hr /> <h2 id="dataset-preview">Dataset preview</h2> <div class="codehilite"><pre><span></span><code><span class="n">text_batch</span><span class="p">,</span> <span class="n">label_batch</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">train_dataset</span><span class="p">))</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">text</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">text_batch</span><span class="p">[:</span><span class="mi">5</span><span class="p">]):</span> <span class="n">label</span> <span class="o">=</span> <span class="n">label_batch</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">numpy</span><span class="p">()[</span><span class="kc">None</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Abstract: </span><span class="si">{</span><span class="n">text</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Label(s): </span><span class="si">{</span><span class="n">invert_multi_hot</span><span class="p">(</span><span class="n">label</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot; &quot;</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Abstract: b&quot;In this paper we show how using satellite images can improve the accuracy of\nhousing price estimation models. Using Los Angeles County&#39;s property assessment\ndataset, by transferring learning from an Inception-v3 model pretrained on\nImageNet, we could achieve an improvement of ~10% in R-squared score compared\nto two baseline models that only use non-image features of the house.&quot; Label(s): [&#39;cs.LG&#39; &#39;stat.ML&#39;] Abstract: b&#39;Learning from data streams is an increasingly important topic in data mining,\nmachine learning, and artificial intelligence in general. A major focus in the\ndata stream literature is on designing methods that can deal with concept\ndrift, a challenge where the generating distribution changes over time. A\ngeneral assumption in most of this literature is that instances are\nindependently distributed in the stream. In this work we show that, in the\ncontext of concept drift, this assumption is contradictory, and that the\npresence of concept drift necessarily implies temporal dependence; and thus\nsome form of time series. This has important implications on model design and\ndeployment. We explore and highlight the these implications, and show that\nHoeffding-tree based ensembles, which are very popular for learning in streams,\nare not naturally suited to learning \\emph{within} drift; and can perform in\nthis scenario only at significant computational cost of destructive adaptation.\nOn the other hand, we develop and parameterize gradient-descent methods and\ndemonstrate how they can perform \\emph{continuous} adaptation with no explicit\ndrift-detection mechanism, offering major advantages in terms of accuracy and\nefficiency. As a consequence of our theoretical discussion and empirical\nobservations, we outline a number of recommendations for deploying methods in\nconcept-drifting streams.&#39; Label(s): [&#39;cs.LG&#39; &#39;stat.ML&#39;] Abstract: b&quot;As reinforcement learning (RL) achieves more success in solving complex\ntasks, more care is needed to ensure that RL research is reproducible and that\nalgorithms herein can be compared easily and fairly with minimal bias. RL\nresults are, however, notoriously hard to reproduce due to the algorithms&#39;\nintrinsic variance, the environments&#39; stochasticity, and numerous (potentially\nunreported) hyper-parameters. In this work we investigate the many issues\nleading to irreproducible research and how to manage those. We further show how\nto utilise a rigorous and standardised evaluation approach for easing the\nprocess of documentation, evaluation and fair comparison of different\nalgorithms, where we emphasise the importance of choosing the right measurement\nmetrics and conducting proper statistics on the results, for unbiased reporting\nof the results.&quot; Label(s): [&#39;cs.LG&#39; &#39;stat.ML&#39; &#39;cs.AI&#39; &#39;cs.RO&#39;] Abstract: b&#39;Estimating dense correspondences between images is a long-standing image\nunder-standing task. Recent works introduce convolutional neural networks\n(CNNs) to extract high-level feature maps and find correspondences through\nfeature matching. However,high-level feature maps are in low spatial resolution\nand therefore insufficient to provide accurate and fine-grained features to\ndistinguish intra-class variations for correspondence matching. To address this\nproblem, we generate robust features by dynamically selecting features at\ndifferent scales. To resolve two critical issues in feature selection,i.e.,how\nmany and which scales of features to be selected, we frame the feature\nselection process as a sequential Markov decision-making process (MDP) and\nintroduce an optimal selection strategy using reinforcement learning (RL). We\ndefine an RL environment for image matching in which each individual action\neither requires new features or terminates the selection episode by referring a\nmatching score. Deep neural networks are incorporated into our method and\ntrained for decision making. Experimental results show that our method achieves\ncomparable/superior performance with state-of-the-art methods on three\nbenchmarks, demonstrating the effectiveness of our feature selection strategy.&#39; Label(s): [&#39;cs.CV&#39;] Abstract: b&#39;Dense reconstructions often contain errors that prior work has so far\nminimised using high quality sensors and regularising the output. Nevertheless,\nerrors still persist. This paper proposes a machine learning technique to\nidentify errors in three dimensional (3D) meshes. Beyond simply identifying\nerrors, our method quantifies both the magnitude and the direction of depth\nestimate errors when viewing the scene. This enables us to improve the\nreconstruction accuracy.\n We train a suitably deep network architecture with two 3D meshes: a\nhigh-quality laser reconstruction, and a lower quality stereo image\nreconstruction. The network predicts the amount of error in the lower quality\nreconstruction with respect to the high-quality one, having only view the\nformer through its input. We evaluate our approach by correcting\ntwo-dimensional (2D) inverse-depth images extracted from the 3D model, and show\nthat our method improves the quality of these depth reconstructions by up to a\nrelative 10% RMSE.&#39; Label(s): [&#39;cs.CV&#39; &#39;cs.RO&#39;] </code></pre></div> </div> <hr /> <h2 id="vectorization">Vectorization</h2> <p>Before we feed the data to our model, we need to vectorize it (represent it in a numerical form). For that purpose, we will use the <a href="https://keras.io/api/layers/preprocessing_layers/text/text_vectorization"><code>TextVectorization</code> layer</a>. It can operate as a part of your main model so that the model is excluded from the core preprocessing logic. This greatly reduces the chances of training / serving skew during inference.</p> <p>We first calculate the number of unique words present in the abstracts.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Source: https://stackoverflow.com/a/18937309/7636462</span> <span class="n">vocabulary</span> <span class="o">=</span> <span class="nb">set</span><span class="p">()</span> <span class="n">train_df</span><span class="p">[</span><span class="s2">&quot;summaries&quot;</span><span class="p">]</span><span class="o">.</span><span class="n">str</span><span class="o">.</span><span class="n">lower</span><span class="p">()</span><span class="o">.</span><span class="n">str</span><span class="o">.</span><span class="n">split</span><span class="p">()</span><span class="o">.</span><span class="n">apply</span><span class="p">(</span><span class="n">vocabulary</span><span class="o">.</span><span class="n">update</span><span class="p">)</span> <span class="n">vocabulary_size</span> <span class="o">=</span> <span class="nb">len</span><span class="p">(</span><span class="n">vocabulary</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="n">vocabulary_size</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>153338 </code></pre></div> </div> <p>We now create our vectorization layer and <code>map()</code> to the <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><code>tf.data.Dataset</code></a>s created earlier.</p> <div class="codehilite"><pre><span></span><code><span class="n">text_vectorizer</span> <span class="o">=</span> <span class="n">layers</span><span class="o">.</span><span class="n">TextVectorization</span><span class="p">(</span> <span class="n">max_tokens</span><span class="o">=</span><span class="n">vocabulary_size</span><span class="p">,</span> <span class="n">ngrams</span><span class="o">=</span><span class="mi">2</span><span class="p">,</span> <span class="n">output_mode</span><span class="o">=</span><span class="s2">&quot;tf_idf&quot;</span> <span class="p">)</span> <span class="c1"># `TextVectorization` layer needs to be adapted as per the vocabulary from our</span> <span class="c1"># training set.</span> <span class="k">with</span> <span class="n">tf</span><span class="o">.</span><span class="n">device</span><span class="p">(</span><span class="s2">&quot;/CPU:0&quot;</span><span class="p">):</span> <span class="n">text_vectorizer</span><span class="o">.</span><span class="n">adapt</span><span class="p">(</span><span class="n">train_dataset</span><span class="o">.</span><span class="n">map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">text</span><span class="p">,</span> <span class="n">label</span><span class="p">:</span> <span class="n">text</span><span class="p">))</span> <span class="n">train_dataset</span> <span class="o">=</span> <span class="n">train_dataset</span><span class="o">.</span><span class="n">map</span><span class="p">(</span> <span class="k">lambda</span> <span class="n">text</span><span class="p">,</span> <span class="n">label</span><span class="p">:</span> <span class="p">(</span><span class="n">text_vectorizer</span><span class="p">(</span><span class="n">text</span><span class="p">),</span> <span class="n">label</span><span class="p">),</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">auto</span> <span class="p">)</span><span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">auto</span><span class="p">)</span> <span class="n">validation_dataset</span> <span class="o">=</span> <span class="n">validation_dataset</span><span class="o">.</span><span class="n">map</span><span class="p">(</span> <span class="k">lambda</span> <span class="n">text</span><span class="p">,</span> <span class="n">label</span><span class="p">:</span> <span class="p">(</span><span class="n">text_vectorizer</span><span class="p">(</span><span class="n">text</span><span class="p">),</span> <span class="n">label</span><span class="p">),</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">auto</span> <span class="p">)</span><span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">auto</span><span class="p">)</span> <span class="n">test_dataset</span> <span class="o">=</span> <span class="n">test_dataset</span><span class="o">.</span><span class="n">map</span><span class="p">(</span> <span class="k">lambda</span> <span class="n">text</span><span class="p">,</span> <span class="n">label</span><span class="p">:</span> <span class="p">(</span><span class="n">text_vectorizer</span><span class="p">(</span><span class="n">text</span><span class="p">),</span> <span class="n">label</span><span class="p">),</span> <span class="n">num_parallel_calls</span><span class="o">=</span><span class="n">auto</span> <span class="p">)</span><span class="o">.</span><span class="n">prefetch</span><span class="p">(</span><span class="n">auto</span><span class="p">)</span> </code></pre></div> <p>A batch of raw text will first go through the <code>TextVectorization</code> layer and it will generate their integer representations. Internally, the <code>TextVectorization</code> layer will first create bi-grams out of the sequences and then represent them using <a href="https://wikipedia.org/wiki/Tf%E2%80%93idf">TF-IDF</a>. The output representations will then be passed to the shallow model responsible for text classification.</p> <p>To learn more about other possible configurations with <code>TextVectorizer</code>, please consult the <a href="https://keras.io/api/layers/preprocessing_layers/text/text_vectorization">official documentation</a>.</p> <p><strong>Note</strong>: Setting the <code>max_tokens</code> argument to a pre-calculated vocabulary size is not a requirement.</p> <hr /> <h2 id="create-a-text-classification-model">Create a text classification model</h2> <p>We will keep our model simple &ndash; it will be a small stack of fully-connected layers with ReLU as the non-linearity.</p> <div class="codehilite"><pre><span></span><code><span class="k">def</span><span class="w"> </span><span class="nf">make_model</span><span class="p">():</span> <span class="n">shallow_mlp_model</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">(</span> <span class="p">[</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">512</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="mi">256</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">),</span> <span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">lookup</span><span class="o">.</span><span class="n">vocabulary_size</span><span class="p">(),</span> <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;sigmoid&quot;</span><span class="p">),</span> <span class="p">]</span> <span class="c1"># More on why &quot;sigmoid&quot; has been used here in a moment.</span> <span class="p">)</span> <span class="k">return</span> <span class="n">shallow_mlp_model</span> </code></pre></div> <hr /> <h2 id="train-the-model">Train the model</h2> <p>We will train our model using the binary crossentropy loss. This is because the labels are not disjoint. For a given abstract, we may have multiple categories. So, we will divide the prediction task into a series of multiple binary classification problems. This is also why we kept the activation function of the classification layer in our model to sigmoid. Researchers have used other combinations of loss function and activation function as well. For example, in <a href="https://arxiv.org/abs/1805.00932">Exploring the Limits of Weakly Supervised Pretraining</a>, Mahajan et al. used the softmax activation function and cross-entropy loss to train their models.</p> <p>There are several options of metrics that can be used in multi-label classification. To keep this code example narrow we decided to use the <a href="https://keras.io/api/metrics/accuracy_metrics/#binaryaccuracy-class">binary accuracy metric</a>. To see the explanation why this metric is used we refer to this <a href="https://github.com/keras-team/keras-io/pull/1133#issuecomment-1322736860">pull-request</a>. There are also other suitable metrics for multi-label classification, like <a href="https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/F1Score">F1 Score</a> or <a href="https://www.tensorflow.org/addons/api_docs/python/tfa/metrics/HammingLoss">Hamming loss</a>.</p> <div class="codehilite"><pre><span></span><code><span class="n">epochs</span> <span class="o">=</span> <span class="mi">20</span> <span class="n">shallow_mlp_model</span> <span class="o">=</span> <span class="n">make_model</span><span class="p">()</span> <span class="n">shallow_mlp_model</span><span class="o">.</span><span class="n">compile</span><span class="p">(</span> <span class="n">loss</span><span class="o">=</span><span class="s2">&quot;binary_crossentropy&quot;</span><span class="p">,</span> <span class="n">optimizer</span><span class="o">=</span><span class="s2">&quot;adam&quot;</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="s2">&quot;binary_accuracy&quot;</span><span class="p">]</span> <span class="p">)</span> <span class="n">history</span> <span class="o">=</span> <span class="n">shallow_mlp_model</span><span class="o">.</span><span class="n">fit</span><span class="p">(</span> <span class="n">train_dataset</span><span class="p">,</span> <span class="n">validation_data</span><span class="o">=</span><span class="n">validation_dataset</span><span class="p">,</span> <span class="n">epochs</span><span class="o">=</span><span class="n">epochs</span> <span class="p">)</span> <span class="k">def</span><span class="w"> </span><span class="nf">plot_result</span><span class="p">(</span><span class="n">item</span><span class="p">):</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="n">item</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="n">item</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">plot</span><span class="p">(</span><span class="n">history</span><span class="o">.</span><span class="n">history</span><span class="p">[</span><span class="s2">&quot;val_&quot;</span> <span class="o">+</span> <span class="n">item</span><span class="p">],</span> <span class="n">label</span><span class="o">=</span><span class="s2">&quot;val_&quot;</span> <span class="o">+</span> <span class="n">item</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s2">&quot;Epochs&quot;</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">ylabel</span><span class="p">(</span><span class="n">item</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">title</span><span class="p">(</span><span class="s2">&quot;Train and Validation </span><span class="si">{}</span><span class="s2"> Over Epochs&quot;</span><span class="o">.</span><span class="n">format</span><span class="p">(</span><span class="n">item</span><span class="p">),</span> <span class="n">fontsize</span><span class="o">=</span><span class="mi">14</span><span class="p">)</span> <span class="n">plt</span><span class="o">.</span><span class="n">legend</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">grid</span><span class="p">()</span> <span class="n">plt</span><span class="o">.</span><span class="n">show</span><span class="p">()</span> <span class="n">plot_result</span><span class="p">(</span><span class="s2">&quot;loss&quot;</span><span class="p">)</span> <span class="n">plot_result</span><span class="p">(</span><span class="s2">&quot;binary_accuracy&quot;</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>Epoch 1/20 258/258 [==============================] - 87s 332ms/step - loss: 0.0326 - binary_accuracy: 0.9893 - val_loss: 0.0189 - val_binary_accuracy: 0.9943 Epoch 2/20 258/258 [==============================] - 100s 387ms/step - loss: 0.0033 - binary_accuracy: 0.9990 - val_loss: 0.0271 - val_binary_accuracy: 0.9940 Epoch 3/20 258/258 [==============================] - 99s 384ms/step - loss: 7.8393e-04 - binary_accuracy: 0.9999 - val_loss: 0.0328 - val_binary_accuracy: 0.9939 Epoch 4/20 258/258 [==============================] - 109s 421ms/step - loss: 3.0132e-04 - binary_accuracy: 1.0000 - val_loss: 0.0366 - val_binary_accuracy: 0.9939 Epoch 5/20 258/258 [==============================] - 105s 405ms/step - loss: 1.6006e-04 - binary_accuracy: 1.0000 - val_loss: 0.0399 - val_binary_accuracy: 0.9939 Epoch 6/20 258/258 [==============================] - 107s 414ms/step - loss: 1.2400e-04 - binary_accuracy: 1.0000 - val_loss: 0.0412 - val_binary_accuracy: 0.9939 Epoch 7/20 258/258 [==============================] - 110s 425ms/step - loss: 7.7131e-05 - binary_accuracy: 1.0000 - val_loss: 0.0439 - val_binary_accuracy: 0.9940 Epoch 8/20 258/258 [==============================] - 105s 405ms/step - loss: 5.5611e-05 - binary_accuracy: 1.0000 - val_loss: 0.0446 - val_binary_accuracy: 0.9940 Epoch 9/20 258/258 [==============================] - 103s 397ms/step - loss: 4.5994e-05 - binary_accuracy: 1.0000 - val_loss: 0.0454 - val_binary_accuracy: 0.9940 Epoch 10/20 258/258 [==============================] - 105s 405ms/step - loss: 3.5126e-05 - binary_accuracy: 1.0000 - val_loss: 0.0472 - val_binary_accuracy: 0.9939 Epoch 11/20 258/258 [==============================] - 109s 422ms/step - loss: 2.9927e-05 - binary_accuracy: 1.0000 - val_loss: 0.0466 - val_binary_accuracy: 0.9940 Epoch 12/20 258/258 [==============================] - 133s 516ms/step - loss: 2.5748e-05 - binary_accuracy: 1.0000 - val_loss: 0.0484 - val_binary_accuracy: 0.9940 Epoch 13/20 258/258 [==============================] - 129s 497ms/step - loss: 4.3529e-05 - binary_accuracy: 1.0000 - val_loss: 0.0500 - val_binary_accuracy: 0.9940 Epoch 14/20 258/258 [==============================] - 158s 611ms/step - loss: 8.1068e-04 - binary_accuracy: 0.9998 - val_loss: 0.0377 - val_binary_accuracy: 0.9936 Epoch 15/20 258/258 [==============================] - 144s 558ms/step - loss: 0.0016 - binary_accuracy: 0.9995 - val_loss: 0.0418 - val_binary_accuracy: 0.9935 Epoch 16/20 258/258 [==============================] - 131s 506ms/step - loss: 0.0018 - binary_accuracy: 0.9995 - val_loss: 0.0479 - val_binary_accuracy: 0.9931 Epoch 17/20 258/258 [==============================] - 127s 491ms/step - loss: 0.0012 - binary_accuracy: 0.9997 - val_loss: 0.0521 - val_binary_accuracy: 0.9931 Epoch 18/20 258/258 [==============================] - 153s 594ms/step - loss: 6.3144e-04 - binary_accuracy: 0.9998 - val_loss: 0.0549 - val_binary_accuracy: 0.9934 Epoch 19/20 258/258 [==============================] - 142s 550ms/step - loss: 3.1753e-04 - binary_accuracy: 0.9999 - val_loss: 0.0589 - val_binary_accuracy: 0.9934 Epoch 20/20 258/258 [==============================] - 153s 594ms/step - loss: 2.0258e-04 - binary_accuracy: 1.0000 - val_loss: 0.0585 - val_binary_accuracy: 0.9933 </code></pre></div> </div> <p><img alt="png" src="/img/examples/nlp/multi_label_classification/multi_label_classification_38_1.png" /></p> <p><img alt="png" src="/img/examples/nlp/multi_label_classification/multi_label_classification_38_2.png" /></p> <p>While training, we notice an initial sharp fall in the loss followed by a gradual decay.</p> <h3 id="evaluate-the-model">Evaluate the model</h3> <div class="codehilite"><pre><span></span><code><span class="n">_</span><span class="p">,</span> <span class="n">binary_acc</span> <span class="o">=</span> <span class="n">shallow_mlp_model</span><span class="o">.</span><span class="n">evaluate</span><span class="p">(</span><span class="n">test_dataset</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Categorical accuracy on the test set: </span><span class="si">{</span><span class="nb">round</span><span class="p">(</span><span class="n">binary_acc</span><span class="w"> </span><span class="o">*</span><span class="w"> </span><span class="mi">100</span><span class="p">,</span><span class="w"> </span><span class="mi">2</span><span class="p">)</span><span class="si">}</span><span class="s2">%.&quot;</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>15/15 [==============================] - 3s 196ms/step - loss: 0.0580 - binary_accuracy: 0.9933 Categorical accuracy on the test set: 99.33%. </code></pre></div> </div> <p>The trained model gives us an evaluation accuracy of ~99%.</p> <hr /> <h2 id="inference">Inference</h2> <p>An important feature of the <a href="https://keras.io/api/layers/preprocessing_layers/">preprocessing layers provided by Keras</a> is that they can be included inside a <a href="https://www.tensorflow.org/api_docs/python/tf/keras/Model"><code>tf.keras.Model</code></a>. We will export an inference model by including the <code>text_vectorization</code> layer on top of <code>shallow_mlp_model</code>. This will allow our inference model to directly operate on raw strings.</p> <p><strong>Note</strong> that during training it is always preferable to use these preprocessing layers as a part of the data input pipeline rather than the model to avoid surfacing bottlenecks for the hardware accelerators. This also allows for asynchronous data processing.</p> <div class="codehilite"><pre><span></span><code><span class="c1"># Create a model for inference.</span> <span class="n">model_for_inference</span> <span class="o">=</span> <span class="n">keras</span><span class="o">.</span><span class="n">Sequential</span><span class="p">([</span><span class="n">text_vectorizer</span><span class="p">,</span> <span class="n">shallow_mlp_model</span><span class="p">])</span> <span class="c1"># Create a small dataset just for demoing inference.</span> <span class="n">inference_dataset</span> <span class="o">=</span> <span class="n">make_dataset</span><span class="p">(</span><span class="n">test_df</span><span class="o">.</span><span class="n">sample</span><span class="p">(</span><span class="mi">100</span><span class="p">),</span> <span class="n">is_train</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span> <span class="n">text_batch</span><span class="p">,</span> <span class="n">label_batch</span> <span class="o">=</span> <span class="nb">next</span><span class="p">(</span><span class="nb">iter</span><span class="p">(</span><span class="n">inference_dataset</span><span class="p">))</span> <span class="n">predicted_probabilities</span> <span class="o">=</span> <span class="n">model_for_inference</span><span class="o">.</span><span class="n">predict</span><span class="p">(</span><span class="n">text_batch</span><span class="p">)</span> <span class="c1"># Perform inference.</span> <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">text</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="n">text_batch</span><span class="p">[:</span><span class="mi">5</span><span class="p">]):</span> <span class="n">label</span> <span class="o">=</span> <span class="n">label_batch</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">numpy</span><span class="p">()[</span><span class="kc">None</span><span class="p">,</span> <span class="o">...</span><span class="p">]</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Abstract: </span><span class="si">{</span><span class="n">text</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Label(s): </span><span class="si">{</span><span class="n">invert_multi_hot</span><span class="p">(</span><span class="n">label</span><span class="p">[</span><span class="mi">0</span><span class="p">])</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span> <span class="n">predicted_proba</span> <span class="o">=</span> <span class="p">[</span><span class="n">proba</span> <span class="k">for</span> <span class="n">proba</span> <span class="ow">in</span> <span class="n">predicted_probabilities</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span> <span class="n">top_3_labels</span> <span class="o">=</span> <span class="p">[</span> <span class="n">x</span> <span class="k">for</span> <span class="n">_</span><span class="p">,</span> <span class="n">x</span> <span class="ow">in</span> <span class="nb">sorted</span><span class="p">(</span> <span class="nb">zip</span><span class="p">(</span><span class="n">predicted_probabilities</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">lookup</span><span class="o">.</span><span class="n">get_vocabulary</span><span class="p">()),</span> <span class="n">key</span><span class="o">=</span><span class="k">lambda</span> <span class="n">pair</span><span class="p">:</span> <span class="n">pair</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">reverse</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="p">)</span> <span class="p">][:</span><span class="mi">3</span><span class="p">]</span> <span class="nb">print</span><span class="p">(</span><span class="sa">f</span><span class="s2">&quot;Predicted Label(s): (</span><span class="si">{</span><span class="s1">&#39;, &#39;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="n">label</span><span class="w"> </span><span class="k">for</span><span class="w"> </span><span class="n">label</span><span class="w"> </span><span class="ow">in</span><span class="w"> </span><span class="n">top_3_labels</span><span class="p">])</span><span class="si">}</span><span class="s2">)&quot;</span><span class="p">)</span> <span class="nb">print</span><span class="p">(</span><span class="s2">&quot; &quot;</span><span class="p">)</span> </code></pre></div> <div class="k-default-codeblock"> <div class="codehilite"><pre><span></span><code>4/4 [==============================] - 0s 62ms/step Abstract: b&#39;We investigate the training of sparse layers that use different parameters\nfor different inputs based on hashing in large Transformer models.\nSpecifically, we modify the feedforward layer to hash to different sets of\nweights depending on the current token, over all tokens in the sequence. We\nshow that this procedure either outperforms or is competitive with\nlearning-to-route mixture-of-expert methods such as Switch Transformers and\nBASE Layers, while requiring no routing parameters or extra terms in the\nobjective function such as a load balancing loss, and no sophisticated\nassignment algorithm. We study the performance of different hashing techniques,\nhash sizes and input features, and show that balanced and random hashes focused\non the most local features work best, compared to either learning clusters or\nusing longer-range context. We show our approach works well both on large\nlanguage modeling and dialogue tasks, and on downstream fine-tuning tasks.&#39; Label(s): [&#39;cs.LG&#39; &#39;cs.CL&#39;] Predicted Label(s): (cs.LG, cs.CL, stat.ML) Abstract: b&#39;We present the first method capable of photorealistically reconstructing\ndeformable scenes using photos/videos captured casually from mobile phones. Our\napproach augments neural radiance fields (NeRF) by optimizing an additional\ncontinuous volumetric deformation field that warps each observed point into a\ncanonical 5D NeRF. We observe that these NeRF-like deformation fields are prone\nto local minima, and propose a coarse-to-fine optimization method for\ncoordinate-based models that allows for more robust optimization. By adapting\nprinciples from geometry processing and physical simulation to NeRF-like\nmodels, we propose an elastic regularization of the deformation field that\nfurther improves robustness. We show that our method can turn casually captured\nselfie photos/videos into deformable NeRF models that allow for photorealistic\nrenderings of the subject from arbitrary viewpoints, which we dub &quot;nerfies.&quot; We\nevaluate our method by collecting time-synchronized data using a rig with two\nmobile phones, yielding train/validation images of the same pose at different\nviewpoints. We show that our method faithfully reconstructs non-rigidly\ndeforming scenes and reproduces unseen views with high fidelity.&#39; Label(s): [&#39;cs.CV&#39; &#39;cs.GR&#39;] Predicted Label(s): (cs.CV, cs.GR, cs.RO) Abstract: b&#39;We propose to jointly learn multi-view geometry and warping between views of\nthe same object instances for robust cross-view object detection. What makes\nmulti-view object instance detection difficult are strong changes in viewpoint,\nlighting conditions, high similarity of neighbouring objects, and strong\nvariability in scale. By turning object detection and instance\nre-identification in different views into a joint learning task, we are able to\nincorporate both image appearance and geometric soft constraints into a single,\nmulti-view detection process that is learnable end-to-end. We validate our\nmethod on a new, large data set of street-level panoramas of urban objects and\nshow superior performance compared to various baselines. Our contribution is\nthreefold: a large-scale, publicly available data set for multi-view instance\ndetection and re-identification; an annotation tool custom-tailored for\nmulti-view instance detection; and a novel, holistic multi-view instance\ndetection and re-identification method that jointly models geometry and\nappearance across views.&#39; Label(s): [&#39;cs.CV&#39; &#39;cs.LG&#39; &#39;stat.ML&#39;] Predicted Label(s): (cs.CV, cs.RO, cs.MM) Abstract: b&#39;Learning graph convolutional networks (GCNs) is an emerging field which aims\nat generalizing deep learning to arbitrary non-regular domains. Most of the\nexisting GCNs follow a neighborhood aggregation scheme, where the\nrepresentation of a node is recursively obtained by aggregating its neighboring\nnode representations using averaging or sorting operations. However, these\noperations are either ill-posed or weak to be discriminant or increase the\nnumber of training parameters and thereby the computational complexity and the\nrisk of overfitting. In this paper, we introduce a novel GCN framework that\nachieves spatial graph convolution in a reproducing kernel Hilbert space\n(RKHS). The latter makes it possible to design, via implicit kernel\nrepresentations, convolutional graph filters in a high dimensional and more\ndiscriminating space without increasing the number of training parameters. The\nparticularity of our GCN model also resides in its ability to achieve\nconvolutions without explicitly realigning nodes in the receptive fields of the\nlearned graph filters with those of the input graphs, thereby making\nconvolutions permutation agnostic and well defined. Experiments conducted on\nthe challenging task of skeleton-based action recognition show the superiority\nof the proposed method against different baselines as well as the related work.&#39; Label(s): [&#39;cs.CV&#39;] Predicted Label(s): (cs.LG, cs.CV, cs.NE) Abstract: b&#39;Recurrent meta reinforcement learning (meta-RL) agents are agents that employ\na recurrent neural network (RNN) for the purpose of &quot;learning a learning\nalgorithm&quot;. After being trained on a pre-specified task distribution, the\nlearned weights of the agent\&#39;s RNN are said to implement an efficient learning\nalgorithm through their activity dynamics, which allows the agent to quickly\nsolve new tasks sampled from the same distribution. However, due to the\nblack-box nature of these agents, the way in which they work is not yet fully\nunderstood. In this study, we shed light on the internal working mechanisms of\nthese agents by reformulating the meta-RL problem using the Partially\nObservable Markov Decision Process (POMDP) framework. We hypothesize that the\nlearned activity dynamics is acting as belief states for such agents. Several\nillustrative experiments suggest that this hypothesis is true, and that\nrecurrent meta-RL agents can be viewed as agents that learn to act optimally in\npartially observable environments consisting of multiple related tasks. This\nview helps in understanding their failure cases and some interesting\nmodel-based results reported in the literature.&#39; Label(s): [&#39;cs.LG&#39; &#39;cs.AI&#39;] Predicted Label(s): (stat.ML, cs.LG, cs.AI) </code></pre></div> </div> <p>The prediction results are not that great but not below the par for a simple model like ours. We can improve this performance with models that consider word order like LSTM or even those that use Transformers (<a href="https://arxiv.org/abs/1706.03762">Vaswani et al.</a>).</p> <hr /> <h2 id="acknowledgements">Acknowledgements</h2> <p>We would like to thank <a href="https://github.com/mattdangerw">Matt Watson</a> for helping us tackle the multi-label binarization part and inverse-transforming the processed labels to the original form.</p> <p>Thanks <a href="https://github.com/cumbalik">Cingis Kratochvil</a> for suggesting and extending this code example by the binary accuracy.</p> </div> <div class='k-outline'> <div class='k-outline-depth-1'> <a href='#largescale-multilabel-text-classification'>Large-scale multi-label text classification</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#introduction'>Introduction</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#imports'>Imports</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#perform-exploratory-data-analysis'>Perform exploratory data analysis</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#convert-the-string-labels-to-lists-of-strings'>Convert the string labels to lists of strings</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#use-stratified-splits-because-of-class-imbalance'>Use stratified splits because of class imbalance</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#multilabel-binarization'>Multi-label binarization</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#data-preprocessing-and-tfdatadataset-objects'>Data preprocessing and <code>tf.data.Dataset</code> objects</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#dataset-preview'>Dataset preview</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#vectorization'>Vectorization</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#create-a-text-classification-model'>Create a text classification model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#train-the-model'>Train the model</a> </div> <div class='k-outline-depth-3'> <a href='#evaluate-the-model'>Evaluate the model</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#inference'>Inference</a> </div> <div class='k-outline-depth-2'> ◆ <a href='#acknowledgements'>Acknowledgements</a> </div> </div> </div> </div> </div> </body> <footer style="float: left; width: 100%; padding: 1em; border-top: solid 1px #bbb;"> <a href="https://policies.google.com/terms">Terms</a> | <a href="https://policies.google.com/privacy">Privacy</a> </footer> </html>

Pages: 1 2 3 4 5 6 7 8 9 10